diff --git a/app/__init__.py b/app/__init__.py index c412e4b8f..5f85ccad6 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -23,7 +23,7 @@ from werkzeug.local import LocalProxy from app.celery.celery import NotifyCelery from app.clients import NotificationProviderClients -from app.clients.cbc_proxy import CBCProxyClient, CBCProxyNoopClient +from app.clients.cbc_proxy import CBCProxyClient from app.clients.document_download import DocumentDownloadClient from app.clients.email.aws_ses import AwsSesClient from app.clients.email.aws_ses_stub import AwsSesStubClient @@ -61,7 +61,7 @@ zendesk_client = ZendeskClient() statsd_client = StatsdClient() redis_store = RedisClient() performance_platform_client = PerformancePlatformClient() -cbc_proxy_client = CBCProxyNoopClient() +cbc_proxy_client = CBCProxyClient() document_download_client = DocumentDownloadClient() metrics = GDSMetrics() @@ -114,9 +114,6 @@ def create_app(application): performance_platform_client.init_app(application) document_download_client.init_app(application) - global cbc_proxy_client - if application.config['CBC_PROXY_AWS_ACCESS_KEY_ID']: - cbc_proxy_client = CBCProxyClient() cbc_proxy_client.init_app(application) register_blueprint(application) diff --git a/app/celery/broadcast_message_tasks.py b/app/celery/broadcast_message_tasks.py index c90c3e297..65fdceeba 100644 --- a/app/celery/broadcast_message_tasks.py +++ b/app/celery/broadcast_message_tasks.py @@ -1,17 +1,33 @@ +import uuid + from flask import current_app from notifications_utils.statsd_decorators import statsd from app import cbc_proxy_client, notify_celery - +from app.config import QueueNames from app.models import BroadcastEventMessageType -from app.dao.broadcast_message_dao import dao_get_broadcast_event_by_id +from app.dao.broadcast_message_dao import dao_get_broadcast_event_by_id, create_broadcast_provider_message @notify_celery.task(name="send-broadcast-event") @statsd(namespace="tasks") def send_broadcast_event(broadcast_event_id): + for provider in current_app.config['ENABLED_CBCS']: + # TODO: Decide whether to send to each provider based on platform admin, service level settings, broadcast + # level settings, etc. + send_broadcast_provider_message.apply_async( + kwargs={'broadcast_event_id': broadcast_event_id, 'provider': provider}, + queue=QueueNames.NOTIFY + ) + + +@notify_celery.task(name="send-broadcast-provider-message") +@statsd(namespace="tasks") +def send_broadcast_provider_message(broadcast_event_id, provider): broadcast_event = dao_get_broadcast_event_by_id(broadcast_event_id) + broadcast_provider_message = create_broadcast_provider_message(broadcast_event, provider) + current_app.logger.info( f'invoking cbc proxy to send ' f'broadcast_event {broadcast_event.reference} ' @@ -23,9 +39,11 @@ def send_broadcast_event(broadcast_event_id): for polygon in broadcast_event.transmitted_areas["simple_polygons"] ] + cbc_proxy_provider_client = cbc_proxy_client.get_proxy(provider) + if broadcast_event.message_type == BroadcastEventMessageType.ALERT: - cbc_proxy_client.create_and_send_broadcast( - identifier=str(broadcast_event.id), + cbc_proxy_provider_client.create_and_send_broadcast( + identifier=str(broadcast_provider_message.id), headline="GOV.UK Notify Broadcast", description=broadcast_event.transmitted_content['body'], areas=areas, @@ -33,22 +51,43 @@ def send_broadcast_event(broadcast_event_id): expires=broadcast_event.transmitted_finishes_at_as_cap_datetime_string, ) elif broadcast_event.message_type == BroadcastEventMessageType.UPDATE: - cbc_proxy_client.update_and_send_broadcast( - identifier=str(broadcast_event.id), + cbc_proxy_provider_client.update_and_send_broadcast( + identifier=str(broadcast_provider_message.id), headline="GOV.UK Notify Broadcast", description=broadcast_event.transmitted_content['body'], areas=areas, - references=broadcast_event.get_earlier_message_references(), + previous_provider_messages=broadcast_event.get_earlier_provider_messages(provider), sent=broadcast_event.sent_at_as_cap_datetime_string, expires=broadcast_event.transmitted_finishes_at_as_cap_datetime_string, ) elif broadcast_event.message_type == BroadcastEventMessageType.CANCEL: - cbc_proxy_client.cancel_broadcast( - identifier=str(broadcast_event.id), + cbc_proxy_provider_client.cancel_broadcast( + identifier=str(broadcast_provider_message.id), headline="GOV.UK Notify Broadcast", description=broadcast_event.transmitted_content['body'], areas=areas, - references=broadcast_event.get_earlier_message_references(), + previous_provider_messages=broadcast_event.get_earlier_provider_messages(provider), sent=broadcast_event.sent_at_as_cap_datetime_string, expires=broadcast_event.transmitted_finishes_at_as_cap_datetime_string, ) + + +@notify_celery.task(name='trigger-link-test') +def trigger_link_test(provider): + """ + Currently we only have one hardcoded CBC Proxy, which corresponds to one + CBC, and so currently we do not specify the CBC Proxy name + + In future we will have multiple CBC proxies, each proxy corresponding to + one MNO's CBC + + This task should invoke other tasks which do the actual link tests, eg: + for cbc_name in app.config.ENABLED_CBCS: + send_link_test_for_cbc(cbc_name) + + Alternatively this task could be configured to be a Celery group + """ + identifier = str(uuid.uuid4()) + message = f"Sending a link test to CBC proxy for provider {provider} with ID {identifier}" + current_app.logger.info(message) + cbc_proxy_client.get_proxy(provider).send_link_test(identifier) diff --git a/app/celery/scheduled_tasks.py b/app/celery/scheduled_tasks.py index eaf898ddf..7fc16a08f 100644 --- a/app/celery/scheduled_tasks.py +++ b/app/celery/scheduled_tasks.py @@ -17,6 +17,7 @@ from app.celery.tasks import ( process_row, process_incomplete_jobs) from app.celery.letters_pdf_tasks import get_pdf_for_templated_letter +from app.celery.broadcast_message_tasks import trigger_link_test from app.config import QueueNames from app.dao.invited_org_user_dao import delete_org_invitations_created_more_than_two_days_ago from app.dao.invited_user_dao import delete_invitations_created_more_than_two_days_ago @@ -300,26 +301,10 @@ def send_canary_to_cbc_proxy(): identifier = str(uuid.uuid4()) message = f"Sending a canary message to CBC proxy with ID {identifier}" current_app.logger.info(message) - cbc_proxy_client.send_canary(identifier) + cbc_proxy_client.get_proxy('canary').send_canary(identifier) @notify_celery.task(name='trigger-link-tests') def trigger_link_tests(): - """ - Currently we only have one hardcoded CBC Proxy, which corresponds to one - CBC, and so currently we do not specify the CBC Proxy name - - In future we will have multiple CBC proxies, each proxy corresponding to - one MNO's CBC - - This task should invoke other tasks which do the actual link tests, eg: - for cbc_name in app.config.ENABLED_CBCS: - send_link_test_for_cbc(cbc_name) - - Alternatively this task could be configured to be a Celery group - """ - for _ in range(1): - identifier = str(uuid.uuid4()) - message = f"Sending a link test to CBC proxy with ID {identifier}" - current_app.logger.info(message) - cbc_proxy_client.send_link_test(identifier) + for cbc_name in current_app.config['ENABLED_CBCS']: + trigger_link_test.apply_async(kwargs={'provider': cbc_name}, queue=QueueNames.NOTIFY) diff --git a/app/clients/cbc_proxy.py b/app/clients/cbc_proxy.py index 3a82f640e..99174121b 100644 --- a/app/clients/cbc_proxy.py +++ b/app/clients/cbc_proxy.py @@ -1,6 +1,9 @@ import json import boto3 +from flask import current_app + +from app.config import BroadcastProvider # The variable names in this file have specific meaning in a CAP message # @@ -14,21 +17,40 @@ import boto3 # * description is a string which populates the areaDesc field # * polygon is a list of lat/long pairs # -# references is a whitespace separated list of message identifiers -# where each identifier is a previous sent message -# ie a Cancel message would have a unique identifier but have the identifier of -# the preceeding Alert message in the references field +# previous_provider_messages is a list of previous events (models.py::BroadcastProviderMessage) +# ie a Cancel message would have a unique event but have the event of +# the preceeding Alert message in the previous_provider_messages field class CBCProxyException(Exception): pass -# Noop = no operation -class CBCProxyNoopClient: +class CBCProxyClient: + _lambda_client = None def init_app(self, app): - pass + if app.config.get('CBC_PROXY_AWS_ACCESS_KEY_ID'): + self._lambda_client = boto3.client( + 'lambda', + region_name='eu-west-2', + aws_access_key_id=app.config['CBC_PROXY_AWS_ACCESS_KEY_ID'], + aws_secret_access_key=app.config['CBC_PROXY_AWS_SECRET_ACCESS_KEY'], + ) + + def get_proxy(self, provider): + proxy_classes = { + 'canary': CBCProxyCanary, + BroadcastProvider.EE: CBCProxyEE, + } + return proxy_classes[provider](self._lambda_client) + + +class CBCProxyClientBase: + lambda_name = None + + def __init__(self, lambda_client): + self._lambda_client = lambda_client def send_canary( self, @@ -52,7 +74,7 @@ class CBCProxyNoopClient: # We have not implementated updating a broadcast def update_and_send_broadcast( self, - identifier, references, headline, description, areas, + identifier, previous_provider_messages, headline, description, areas, sent, expires, ): pass @@ -60,27 +82,22 @@ class CBCProxyNoopClient: # We have not implemented cancelling a broadcast def cancel_broadcast( self, - identifier, references, headline, description, areas, + identifier, previous_provider_messages, headline, description, areas, sent, expires, ): pass + def _invoke_lambda(self, payload): + if not self.lambda_name: + current_app.logger.warning( + '{self.__class__.__name__} tried to send {payload} but cbc proxy aws env vars not set' + ) + return -class CBCProxyClient: - - def init_app(self, app): - self._lambda_client = boto3.client( - 'lambda', - region_name='eu-west-2', - aws_access_key_id=app.config['CBC_PROXY_AWS_ACCESS_KEY_ID'], - aws_secret_access_key=app.config['CBC_PROXY_AWS_SECRET_ACCESS_KEY'], - ) - - def _invoke_lambda(self, function_name, payload): payload_bytes = bytes(json.dumps(payload), encoding='utf8') result = self._lambda_client.invoke( - FunctionName=function_name, + FunctionName=self.lambda_name, InvocationType='RequestResponse', Payload=payload_bytes, ) @@ -93,19 +110,35 @@ class CBCProxyClient: return result + +class CBCProxyCanary(CBCProxyClientBase): + """ + The canary is a lambda which tests notify's connectivity to the Cell Broadcast AWS infrastructure. It calls the + canary, a specific lambda that does not open a vpn or connect to a provider but just responds from within AWS. + """ + lambda_name = 'canary' + def send_canary( self, identifier, ): - self._invoke_lambda(function_name='canary', payload={'identifier': identifier}) + self._invoke_lambda(payload={'identifier': identifier}) + + +class CBCProxyEE(CBCProxyClientBase): + lambda_name = 'bt-ee-1-proxy' def send_link_test( self, identifier, ): + """ + link test - open up a connection to a specific provider, and send them an xml payload with a of + test. + """ payload = {'message_type': 'test', 'identifier': identifier} - self._invoke_lambda(function_name='bt-ee-1-proxy', payload=payload) + self._invoke_lambda(payload=payload) def create_and_send_broadcast( self, @@ -121,21 +154,4 @@ class CBCProxyClient: 'sent': sent, 'expires': expires, } - - self._invoke_lambda(function_name='bt-ee-1-proxy', payload=payload) - - # We have not implementated updating a broadcast - def update_and_send_broadcast( - self, - identifier, references, headline, description, areas, - sent, expires, - ): - pass - - # We have not implemented cancelling a broadcast - def cancel_broadcast( - self, - identifier, references, headline, description, areas, - sent, expires, - ): - pass + self._invoke_lambda(payload=payload) diff --git a/app/config.py b/app/config.py index 3809a3109..c6871a6b5 100644 --- a/app/config.py +++ b/app/config.py @@ -56,6 +56,15 @@ class QueueNames(object): ] +class BroadcastProvider: + EE = 'ee' + VODAFONE = 'vodafone' + THREE = 'three' + O2 = 'o2' + + PROVIDERS = [EE, VODAFONE, THREE, O2] + + class TaskNames(object): PROCESS_INCOMPLETE_JOBS = 'process-incomplete-jobs' ZIP_AND_SEND_LETTER_PDFS = 'zip-and-send-letter-pdfs' @@ -367,6 +376,8 @@ class Config(object): CBC_PROXY_AWS_ACCESS_KEY_ID = os.environ.get('CBC_PROXY_AWS_ACCESS_KEY_ID', '') CBC_PROXY_AWS_SECRET_ACCESS_KEY = os.environ.get('CBC_PROXY_AWS_SECRET_ACCESS_KEY', '') + ENABLED_CBCS = {BroadcastProvider.EE} + ###################### # Config overrides ### diff --git a/app/dao/broadcast_message_dao.py b/app/dao/broadcast_message_dao.py index 0c7f560a7..c24e93997 100644 --- a/app/dao/broadcast_message_dao.py +++ b/app/dao/broadcast_message_dao.py @@ -1,4 +1,6 @@ -from app.models import BroadcastMessage, BroadcastEvent +from app import db +from app.dao.dao_utils import transactional +from app.models import BroadcastMessage, BroadcastEvent, BroadcastProviderMessage, BroadcastProviderMessageStatus def dao_get_broadcast_message_by_id_and_service_id(broadcast_message_id, service_id): @@ -34,3 +36,14 @@ def get_earlier_events_for_broadcast_event(broadcast_event_id): ).order_by( BroadcastEvent.sent_at.asc() ).all() + + +@transactional +def create_broadcast_provider_message(broadcast_event, provider): + provider_message = BroadcastProviderMessage( + broadcast_event=broadcast_event, + provider=provider, + status=BroadcastProviderMessageStatus.SENDING, + ) + db.session.add(provider_message) + return provider_message diff --git a/app/models.py b/app/models.py index db4a2b3d0..1d5ea5174 100644 --- a/app/models.py +++ b/app/models.py @@ -2195,6 +2195,10 @@ class BroadcastStatusType(db.Model): class BroadcastMessage(db.Model): + """ + This is for creating a message, viewing it in notify, adding areas, approvals, drafts, etc. Notify logic before + hitting send. + """ __tablename__ = 'broadcast_message' __table_args__ = ( db.ForeignKeyConstraint( @@ -2299,7 +2303,8 @@ class BroadcastEventMessageType: class BroadcastEvent(db.Model): """ - This table represents a single CAP XML blob that we sent to the mobile network providers. + This table represents an instruction that we will send to the broadcast providers. It directly correlates with an + instruction from the admin - to broadcast a message, to cancel an existing message, or to update an existing one. We should be able to create the complete CAP message without joining from this to any other tables, eg template, service, or broadcast_message. @@ -2372,9 +2377,40 @@ class BroadcastEvent(db.Model): """ return f"{dt.strftime('%Y-%m-%dT%H:%M:%S')}-00:00" - def get_earlier_message_references(self): + def get_provider_message(self, provider): + return next( + ( + provider_message + for provider_message in self.provider_messages + if provider_message.provider == provider + ), + None + ) + + def get_earlier_provider_messages(self, provider): + """ + Get the previous message for a provider. These are differentper provider, as the identifiers are different. + Return the full provider_message object rather than just an identifier, since the different providers expect + reference to contain different things - let the cbc_proxy work out what information is relevant. + """ from app.dao.broadcast_message_dao import get_earlier_events_for_broadcast_event - return [event.reference for event in get_earlier_events_for_broadcast_event(self.id)] + earlier_events = [ + event for event in get_earlier_events_for_broadcast_event(self.id) + ] + ret = [] + for event in earlier_events: + provider_message = event.get_provider_message(provider) + if provider_message is None: + # TODO: We should figure out what to do if a previous message hasn't been sent out yet. + # We don't want to not cancel a message just because it's stuck in a queue somewhere. + # This exception should probably be named, and then should be caught further up and handled + # appropriately. + raise Exception( + f'Cannot get earlier message references for event {self.id}, previous event {event.id} has not ' + + f' been sent to provider "{provider}" yet' + ) + ret.append(provider_message) + return ret def serialize(self): return { @@ -2382,8 +2418,6 @@ class BroadcastEvent(db.Model): 'service_id': str(self.service_id), - 'previous_event_references': self.get_earlier_message_references(), - 'broadcast_message_id': str(self.broadcast_message_id), # sent_at is required by BroadcastMessageTemplate.from_broadcast_event 'sent_at': self.sent_at.strftime(DATETIME_FORMAT), @@ -2398,3 +2432,43 @@ class BroadcastEvent(db.Model): 'transmitted_finishes_at': self.transmitted_finishes_at.strftime(DATETIME_FORMAT), } + + +class BroadcastProvider: + EE = 'ee' + VODAFONE = 'vodafone' + THREE = 'three' + O2 = 'o2' + + PROVIDERS = [EE, VODAFONE, THREE, O2] + + +class BroadcastProviderMessageStatus: + TECHNICAL_FAILURE = 'technical-failure' # Couldn’t send (cbc proxy 5xx/4xx) + SENDING = 'sending' # Sent to cbc, awaiting response + ACK = 'returned-ack' # Received ack response + ERR = 'returned-error' # Received error response + + STATES = [TECHNICAL_FAILURE, SENDING, ACK, ERR] + + +class BroadcastProviderMessage(db.Model): + """ + A row in this table represents the XML blob sent to a single provider. + """ + __tablename__ = 'broadcast_provider_message' + + id = db.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + + broadcast_event_id = db.Column(UUID(as_uuid=True), db.ForeignKey('broadcast_event.id')) + broadcast_event = db.relationship('BroadcastEvent', backref='provider_messages') + + # 'ee', 'three', 'vodafone', etc + provider = db.Column(db.String) + + status = db.Column(db.String) + + created_at = db.Column(db.DateTime, nullable=False, default=datetime.datetime.utcnow) + updated_at = db.Column(db.DateTime, nullable=True, onupdate=datetime.datetime.utcnow) + + UniqueConstraint(broadcast_event_id, provider) diff --git a/migrations/versions/0332_broadcast_provider_msg.py b/migrations/versions/0332_broadcast_provider_msg.py new file mode 100644 index 000000000..088f1c9df --- /dev/null +++ b/migrations/versions/0332_broadcast_provider_msg.py @@ -0,0 +1,49 @@ +""" + +Revision ID: 0332_broadcast_provider_msg +Revises: 0331_add_broadcast_org +Create Date: 2020-10-26 16:28:11.917468 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +revision = '0332_broadcast_provider_msg' +down_revision = '0331_add_broadcast_org' + +STATUSES = [ + 'technical-failure', + 'sending', + 'returned-ack', + 'returned-error', +] + + +def upgrade(): + + broadcast_provider_message_status_type = op.create_table( + 'broadcast_provider_message_status_type', + sa.Column('name', sa.String(), nullable=False), + sa.PrimaryKeyConstraint('name') + ) + op.bulk_insert(broadcast_provider_message_status_type, [{'name': status} for status in STATUSES]) + + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + 'broadcast_provider_message', + sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False), + sa.Column('broadcast_event_id', postgresql.UUID(as_uuid=True), nullable=True), + sa.Column('provider', sa.String(), nullable=True), + sa.Column('status', sa.String(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['broadcast_event_id'], ['broadcast_event.id'], ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('broadcast_event_id', 'provider') + ) + + +def downgrade(): + op.drop_table('broadcast_provider_message') + op.drop_table('broadcast_provider_message_status_type') diff --git a/tests/app/celery/test_broadcast_message_tasks.py b/tests/app/celery/test_broadcast_message_tasks.py index d6b5cdb4e..097f940a0 100644 --- a/tests/app/celery/test_broadcast_message_tasks.py +++ b/tests/app/celery/test_broadcast_message_tasks.py @@ -1,13 +1,38 @@ +import uuid +from unittest.mock import call, ANY + from freezegun import freeze_time import pytest -from app.models import BROADCAST_TYPE, BroadcastStatusType, BroadcastEventMessageType -from app.celery.broadcast_message_tasks import send_broadcast_event -from tests.app.db import create_template, create_broadcast_message, create_broadcast_event +from app.models import BROADCAST_TYPE, BroadcastStatusType, BroadcastEventMessageType, BroadcastProviderMessageStatus +from app.celery.broadcast_message_tasks import send_broadcast_event, send_broadcast_provider_message, trigger_link_test + +from tests.app.db import ( + create_template, + create_broadcast_message, + create_broadcast_event, + create_broadcast_provider_message +) +from tests.conftest import set_config + + +def test_send_broadcast_event_queues_up_for_active_providers(mocker, notify_api): + mock_send_broadcast_provider_message = mocker.patch( + 'app.celery.broadcast_message_tasks.send_broadcast_provider_message', + ) + + event_id = uuid.uuid4() + with set_config(notify_api, 'ENABLED_CBCS', ['ee', 'vodafone']): + send_broadcast_event(event_id) + + assert mock_send_broadcast_provider_message.apply_async.call_args_list == [ + call(kwargs={'broadcast_event_id': event_id, 'provider': 'ee'}, queue='notify-internal-tasks'), + call(kwargs={'broadcast_event_id': event_id, 'provider': 'vodafone'}, queue='notify-internal-tasks') + ] @freeze_time('2020-08-01 12:00') -def test_create_broadcast_event_sends_data_correctly(mocker, sample_service): +def test_send_broadcast_provider_message_sends_data_correctly(mocker, sample_service): template = create_template(sample_service, BROADCAST_TYPE) broadcast_message = create_broadcast_message( template, @@ -23,13 +48,18 @@ def test_create_broadcast_event_sends_data_correctly(mocker, sample_service): event = create_broadcast_event(broadcast_message) mock_create_broadcast = mocker.patch( - 'app.cbc_proxy_client.create_and_send_broadcast', + 'app.clients.cbc_proxy.CBCProxyEE.create_and_send_broadcast', ) - send_broadcast_event(broadcast_event_id=str(event.id)) + assert event.get_provider_message('ee') is None + + send_broadcast_provider_message(provider='ee', broadcast_event_id=str(event.id)) + + broadcast_provider_message = event.get_provider_message('ee') + assert broadcast_provider_message.status == BroadcastProviderMessageStatus.SENDING mock_create_broadcast.assert_called_once_with( - identifier=str(event.id), + identifier=str(broadcast_provider_message.id), headline='GOV.UK Notify Broadcast', description='this is an emergency broadcast message', areas=[{ @@ -46,7 +76,7 @@ def test_create_broadcast_event_sends_data_correctly(mocker, sample_service): ) -def test_update_broadcast_event_sends_references(mocker, sample_service): +def test_send_broadcast_provider_message_sends_update_with_references(mocker, sample_service): template = create_template(sample_service, BROADCAST_TYPE, content='content') broadcast_message = create_broadcast_message( @@ -61,28 +91,34 @@ def test_update_broadcast_event_sends_references(mocker, sample_service): ) alert_event = create_broadcast_event(broadcast_message, message_type=BroadcastEventMessageType.ALERT) + create_broadcast_provider_message(alert_event, 'ee') update_event = create_broadcast_event(broadcast_message, message_type=BroadcastEventMessageType.UPDATE) mock_update_broadcast = mocker.patch( - 'app.cbc_proxy_client.update_and_send_broadcast', + 'app.clients.cbc_proxy.CBCProxyEE.update_and_send_broadcast', ) - send_broadcast_event(broadcast_event_id=str(update_event.id)) + send_broadcast_provider_message(provider='ee', broadcast_event_id=str(update_event.id)) + + broadcast_provider_message = update_event.get_provider_message('ee') + assert broadcast_provider_message.status == BroadcastProviderMessageStatus.SENDING mock_update_broadcast.assert_called_once_with( - identifier=str(update_event.id), + identifier=str(broadcast_provider_message.id), headline="GOV.UK Notify Broadcast", description='this is an emergency broadcast message', areas=[{ "polygon": [[50.12, 1.2], [50.13, 1.2], [50.14, 1.21]], }], - references=[alert_event.reference], + previous_provider_messages=[ + alert_event.get_provider_message('ee') + ], sent=update_event.sent_at_as_cap_datetime_string, expires=update_event.transmitted_finishes_at_as_cap_datetime_string, ) -def test_cancel_broadcast_event_sends_references(mocker, sample_service): +def test_send_broadcast_provider_message_sends_cancel_with_references(mocker, sample_service): template = create_template(sample_service, BROADCAST_TYPE, content='content') broadcast_message = create_broadcast_message( @@ -100,26 +136,35 @@ def test_cancel_broadcast_event_sends_references(mocker, sample_service): update_event = create_broadcast_event(broadcast_message, message_type=BroadcastEventMessageType.UPDATE) cancel_event = create_broadcast_event(broadcast_message, message_type=BroadcastEventMessageType.CANCEL) + create_broadcast_provider_message(alert_event, 'ee') + create_broadcast_provider_message(update_event, 'ee') + mock_cancel_broadcast = mocker.patch( - 'app.cbc_proxy_client.cancel_broadcast', + 'app.clients.cbc_proxy.CBCProxyEE.cancel_broadcast', ) - send_broadcast_event(broadcast_event_id=str(cancel_event.id)) + send_broadcast_provider_message(provider='ee', broadcast_event_id=str(cancel_event.id)) + + broadcast_provider_message = cancel_event.get_provider_message('ee') + assert broadcast_provider_message.status == BroadcastProviderMessageStatus.SENDING mock_cancel_broadcast.assert_called_once_with( - identifier=str(cancel_event.id), + identifier=str(broadcast_provider_message.id), headline="GOV.UK Notify Broadcast", description='this is an emergency broadcast message', areas=[{ "polygon": [[50.12, 1.2], [50.13, 1.2], [50.14, 1.21]], }], - references=[alert_event.reference, update_event.reference], + previous_provider_messages=[ + alert_event.get_provider_message('ee'), + update_event.get_provider_message('ee') + ], sent=cancel_event.sent_at_as_cap_datetime_string, expires=cancel_event.transmitted_finishes_at_as_cap_datetime_string, ) -def test_send_broadcast_event_errors(mocker, sample_service): +def test_send_broadcast_provider_message_errors(mocker, sample_service): template = create_template(sample_service, BROADCAST_TYPE) broadcast_message = create_broadcast_message( @@ -136,17 +181,17 @@ def test_send_broadcast_event_errors(mocker, sample_service): event = create_broadcast_event(broadcast_message) mock_create_broadcast = mocker.patch( - 'app.cbc_proxy_client.create_and_send_broadcast', + 'app.clients.cbc_proxy.CBCProxyEE.create_and_send_broadcast', side_effect=Exception('oh no'), ) with pytest.raises(Exception) as ex: - send_broadcast_event(broadcast_event_id=str(event.id)) + send_broadcast_provider_message(provider='ee', broadcast_event_id=str(event.id)) assert ex.match('oh no') mock_create_broadcast.assert_called_once_with( - identifier=str(event.id), + identifier=ANY, headline="GOV.UK Notify Broadcast", description='this is an emergency broadcast message', areas=[{ @@ -159,3 +204,22 @@ def test_send_broadcast_event_errors(mocker, sample_service): sent=event.sent_at_as_cap_datetime_string, expires=event.transmitted_finishes_at_as_cap_datetime_string, ) + + +def test_trigger_link_tests_invokes_cbc_proxy_client( + mocker, +): + mock_send_link_test = mocker.patch( + 'app.clients.cbc_proxy.CBCProxyEE.send_link_test', + ) + + trigger_link_test('ee') + + assert mock_send_link_test.called + # the 0th argument of the call to send_link_test + identifier = mock_send_link_test.mock_calls[0][1][0] + + try: + uuid.UUID(identifier) + except BaseException: + pytest.fail(f"{identifier} is not a valid uuid") diff --git a/tests/app/celery/test_scheduled_tasks.py b/tests/app/celery/test_scheduled_tasks.py index c76632c2e..8ea0613a6 100644 --- a/tests/app/celery/test_scheduled_tasks.py +++ b/tests/app/celery/test_scheduled_tasks.py @@ -19,6 +19,7 @@ from app.celery.scheduled_tasks import ( check_for_missing_rows_in_completed_jobs, check_for_services_with_high_failure_rates_or_sending_to_tv_numbers, switch_current_sms_provider_on_slow_delivery, + trigger_link_tests, ) from app.config import QueueNames, Config from app.dao.jobs_dao import dao_get_job_by_id @@ -30,8 +31,8 @@ from app.models import ( NOTIFICATION_DELIVERED, NOTIFICATION_PENDING_VIRUS_CHECK, ) +from tests.conftest import set_config from tests.app import load_example_csv - from tests.app.db import ( create_notification, create_template, @@ -560,9 +561,10 @@ def test_check_for_services_with_high_failure_rates_or_sending_to_tv_numbers( def test_send_canary_to_cbc_proxy_invokes_cbc_proxy_client( mocker, + notify_api ): mock_send_canary = mocker.patch( - 'app.cbc_proxy_client.send_canary', + 'app.clients.cbc_proxy.CBCProxyCanary.send_canary', ) scheduled_tasks.send_canary_to_cbc_proxy() @@ -577,20 +579,17 @@ def test_send_canary_to_cbc_proxy_invokes_cbc_proxy_client( pytest.fail(f"{identifier} is not a valid uuid") -def test_trigger_link_tests_invokes_cbc_proxy_client( - mocker, +def test_trigger_link_tests_calls_for_all_providers( + mocker, notify_api ): - mock_send_link_test = mocker.patch( - 'app.cbc_proxy_client.send_link_test', + mock_trigger_link_test = mocker.patch( + 'app.celery.scheduled_tasks.trigger_link_test', ) - scheduled_tasks.trigger_link_tests() + with set_config(notify_api, 'ENABLED_CBCS', ['ee', 'vodafone']): + trigger_link_tests() - mock_send_link_test.assert_called - # the 0th argument of the call to send_link_test - identifier = mock_send_link_test.mock_calls[0][1][0] - - try: - uuid.UUID(identifier) - except BaseException: - pytest.fail(f"{identifier} is not a valid uuid") + assert mock_trigger_link_test.apply_async.call_args_list == [ + call(kwargs={'provider': 'ee'}, queue='notify-internal-tasks'), + call(kwargs={'provider': 'vodafone'}, queue='notify-internal-tasks') + ] diff --git a/tests/app/clients/test_cbc_proxy.py b/tests/app/clients/test_cbc_proxy.py index d888d44b6..10534855b 100644 --- a/tests/app/clients/test_cbc_proxy.py +++ b/tests/app/clients/test_cbc_proxy.py @@ -1,13 +1,14 @@ import json import uuid +from unittest.mock import Mock import pytest -from app.clients.cbc_proxy import CBCProxyClient, CBCProxyException +from app.clients.cbc_proxy import CBCProxyClient, CBCProxyException, CBCProxyEE, CBCProxyCanary @pytest.fixture(scope='function') -def cbc_proxy(client, mocker): +def cbc_proxy_client(client, mocker): client = CBCProxyClient() current_app = mocker.Mock(config={ 'CBC_PROXY_AWS_ACCESS_KEY_ID': 'cbc-proxy-aws-access-key-id', @@ -17,19 +18,39 @@ def cbc_proxy(client, mocker): return client -def test_cbc_proxy_lambda_client_has_correct_region(cbc_proxy): - assert cbc_proxy._lambda_client._client_config.region_name == 'eu-west-2' +@pytest.fixture +def cbc_proxy_ee(cbc_proxy_client): + return cbc_proxy_client.get_proxy('ee') -def test_cbc_proxy_lambda_client_has_correct_keys(cbc_proxy): - key = cbc_proxy._lambda_client._request_signer._credentials.access_key - secret = cbc_proxy._lambda_client._request_signer._credentials.secret_key +@pytest.mark.parametrize('provider_name, expected_provider_class', [ + ('ee', CBCProxyEE), + ('canary', CBCProxyCanary), +]) +def test_cbc_proxy_client_returns_correct_client(provider_name, expected_provider_class): + mock_lambda = Mock() + cbc_proxy_client = CBCProxyClient() + cbc_proxy_client._lambda_client = mock_lambda + + ret = cbc_proxy_client.get_proxy(provider_name) + + assert type(ret) == expected_provider_class + assert ret._lambda_client == mock_lambda + + +def test_cbc_proxy_lambda_client_has_correct_region(cbc_proxy_ee): + assert cbc_proxy_ee._lambda_client._client_config.region_name == 'eu-west-2' + + +def test_cbc_proxy_lambda_client_has_correct_keys(cbc_proxy_ee): + key = cbc_proxy_ee._lambda_client._request_signer._credentials.access_key + secret = cbc_proxy_ee._lambda_client._request_signer._credentials.secret_key assert key == 'cbc-proxy-aws-access-key-id' assert secret == 'cbc-proxy-aws-secret-access-key' -def test_cbc_proxy_create_and_send_invokes_function(mocker, cbc_proxy): +def test_cbc_proxy_create_and_send_invokes_function(mocker, cbc_proxy_ee): identifier = 'my-identifier' headline = 'my-headline' description = 'my-description' @@ -50,7 +71,7 @@ def test_cbc_proxy_create_and_send_invokes_function(mocker, cbc_proxy): }] ld_client_mock = mocker.patch.object( - cbc_proxy, + cbc_proxy_ee, '_lambda_client', create=True, ) @@ -59,7 +80,7 @@ def test_cbc_proxy_create_and_send_invokes_function(mocker, cbc_proxy): 'StatusCode': 200, } - cbc_proxy.create_and_send_broadcast( + cbc_proxy_ee.create_and_send_broadcast( identifier=identifier, headline=headline, description=description, @@ -86,7 +107,7 @@ def test_cbc_proxy_create_and_send_invokes_function(mocker, cbc_proxy): assert payload['expires'] == expires -def test_cbc_proxy_create_and_send_handles_invoke_error(mocker, cbc_proxy): +def test_cbc_proxy_create_and_send_handles_invoke_error(mocker, cbc_proxy_ee): identifier = 'my-identifier' headline = 'my-headline' description = 'my-description' @@ -107,7 +128,7 @@ def test_cbc_proxy_create_and_send_handles_invoke_error(mocker, cbc_proxy): }] ld_client_mock = mocker.patch.object( - cbc_proxy, + cbc_proxy_ee, '_lambda_client', create=True, ) @@ -117,7 +138,7 @@ def test_cbc_proxy_create_and_send_handles_invoke_error(mocker, cbc_proxy): } with pytest.raises(CBCProxyException) as e: - cbc_proxy.create_and_send_broadcast( + cbc_proxy_ee.create_and_send_broadcast( identifier=identifier, headline=headline, description=description, @@ -134,7 +155,7 @@ def test_cbc_proxy_create_and_send_handles_invoke_error(mocker, cbc_proxy): ) -def test_cbc_proxy_create_and_send_handles_function_error(mocker, cbc_proxy): +def test_cbc_proxy_create_and_send_handles_function_error(mocker, cbc_proxy_ee): identifier = 'my-identifier' headline = 'my-headline' description = 'my-description' @@ -155,7 +176,7 @@ def test_cbc_proxy_create_and_send_handles_function_error(mocker, cbc_proxy): }] ld_client_mock = mocker.patch.object( - cbc_proxy, + cbc_proxy_ee, '_lambda_client', create=True, ) @@ -166,7 +187,7 @@ def test_cbc_proxy_create_and_send_handles_function_error(mocker, cbc_proxy): } with pytest.raises(CBCProxyException) as e: - cbc_proxy.create_and_send_broadcast( + cbc_proxy_ee.create_and_send_broadcast( identifier=identifier, headline=headline, description=description, @@ -183,11 +204,13 @@ def test_cbc_proxy_create_and_send_handles_function_error(mocker, cbc_proxy): ) -def test_cbc_proxy_send_canary_invokes_function(mocker, cbc_proxy): +def test_cbc_proxy_send_canary_invokes_function(mocker, cbc_proxy_client): identifier = str(uuid.uuid4()) + canary_client = cbc_proxy_client.get_proxy('canary') + ld_client_mock = mocker.patch.object( - cbc_proxy, + canary_client, '_lambda_client', create=True, ) @@ -196,7 +219,7 @@ def test_cbc_proxy_send_canary_invokes_function(mocker, cbc_proxy): 'StatusCode': 200, } - cbc_proxy.send_canary( + canary_client.send_canary( identifier=identifier, ) @@ -213,66 +236,11 @@ def test_cbc_proxy_send_canary_invokes_function(mocker, cbc_proxy): assert payload['identifier'] == identifier -def test_cbc_proxy_send_canary_handles_invoke_error(mocker, cbc_proxy): +def test_cbc_proxy_send_link_test_invokes_function(mocker, cbc_proxy_ee): identifier = str(uuid.uuid4()) ld_client_mock = mocker.patch.object( - cbc_proxy, - '_lambda_client', - create=True, - ) - - ld_client_mock.invoke.return_value = { - 'StatusCode': 400, - } - - with pytest.raises(CBCProxyException) as e: - cbc_proxy.send_canary( - identifier=identifier, - ) - - assert e.match('Could not invoke lambda') - - ld_client_mock.invoke.assert_called_once_with( - FunctionName='canary', - InvocationType='RequestResponse', - Payload=mocker.ANY, - ) - - -def test_cbc_proxy_send_canary_handles_function_error(mocker, cbc_proxy): - identifier = str(uuid.uuid4()) - - ld_client_mock = mocker.patch.object( - cbc_proxy, - '_lambda_client', - create=True, - ) - - ld_client_mock.invoke.return_value = { - 'StatusCode': 200, - 'FunctionError': 'something', - } - - with pytest.raises(CBCProxyException) as e: - cbc_proxy.send_canary( - identifier=identifier, - ) - - assert e.match('Function exited with unhandled exception') - - ld_client_mock.invoke.assert_called_once_with( - FunctionName='canary', - InvocationType='RequestResponse', - Payload=mocker.ANY, - ) - - -def test_cbc_proxy_send_link_test_invokes_function(mocker, cbc_proxy): - identifier = str(uuid.uuid4()) - - ld_client_mock = mocker.patch.object( - cbc_proxy, + cbc_proxy_ee, '_lambda_client', create=True, ) @@ -281,7 +249,7 @@ def test_cbc_proxy_send_link_test_invokes_function(mocker, cbc_proxy): 'StatusCode': 200, } - cbc_proxy.send_link_test( + cbc_proxy_ee.send_link_test( identifier=identifier, ) @@ -297,58 +265,3 @@ def test_cbc_proxy_send_link_test_invokes_function(mocker, cbc_proxy): assert payload['identifier'] == identifier assert payload['message_type'] == 'test' - - -def test_cbc_proxy_send_link_test_handles_invoke_error(mocker, cbc_proxy): - identifier = str(uuid.uuid4()) - - ld_client_mock = mocker.patch.object( - cbc_proxy, - '_lambda_client', - create=True, - ) - - ld_client_mock.invoke.return_value = { - 'StatusCode': 400, - } - - with pytest.raises(CBCProxyException) as e: - cbc_proxy.send_link_test( - identifier=identifier, - ) - - assert e.match('Could not invoke lambda') - - ld_client_mock.invoke.assert_called_once_with( - FunctionName='bt-ee-1-proxy', - InvocationType='RequestResponse', - Payload=mocker.ANY, - ) - - -def test_cbc_proxy_send_link_test_handles_function_error(mocker, cbc_proxy): - identifier = str(uuid.uuid4()) - - ld_client_mock = mocker.patch.object( - cbc_proxy, - '_lambda_client', - create=True, - ) - - ld_client_mock.invoke.return_value = { - 'StatusCode': 200, - 'FunctionError': 'something', - } - - with pytest.raises(CBCProxyException) as e: - cbc_proxy.send_link_test( - identifier=identifier, - ) - - assert e.match('Function exited with unhandled exception') - - ld_client_mock.invoke.assert_called_once_with( - FunctionName='bt-ee-1-proxy', - InvocationType='RequestResponse', - Payload=mocker.ANY, - ) diff --git a/tests/app/dao/test_broadcast_message_dao.py b/tests/app/dao/test_broadcast_message_dao.py index cefebbbd7..a8fb160b5 100644 --- a/tests/app/dao/test_broadcast_message_dao.py +++ b/tests/app/dao/test_broadcast_message_dao.py @@ -1,7 +1,7 @@ from datetime import datetime -from app.models import BROADCAST_TYPE -from app.models import BroadcastEventMessageType -from app.dao.broadcast_message_dao import get_earlier_events_for_broadcast_event + +from app.models import BROADCAST_TYPE, BroadcastEventMessageType +from app.dao.broadcast_message_dao import get_earlier_events_for_broadcast_event, create_broadcast_provider_message from tests.app.db import create_broadcast_message, create_template, create_broadcast_event @@ -41,3 +41,21 @@ def test_get_earlier_events_for_broadcast_event(sample_service): # only fetches earlier events, and they're in time order earlier_events = get_earlier_events_for_broadcast_event(events[2].id) assert earlier_events == [events[0], events[1]] + + +def test_create_broadcast_provider_message_creates_in_correct_state(sample_broadcast_service): + t = create_template(sample_broadcast_service, BROADCAST_TYPE) + broadcast_message = create_broadcast_message(t) + broadcast_event = create_broadcast_event( + broadcast_message, + sent_at=datetime(2020, 1, 1, 12, 0, 0), + message_type=BroadcastEventMessageType.ALERT, + transmitted_content={'body': 'Initial content'} + ) + + broadcast_provider_message = create_broadcast_provider_message(broadcast_event, 'fake-provider') + + assert broadcast_provider_message.status == 'sending' + assert broadcast_provider_message.broadcast_event_id == broadcast_event.id + assert broadcast_provider_message.created_at is not None + assert broadcast_provider_message.updated_at is None diff --git a/tests/app/db.py b/tests/app/db.py index 969a20e28..30f00e7b6 100644 --- a/tests/app/db.py +++ b/tests/app/db.py @@ -62,7 +62,8 @@ from app.models import ( ServiceContactList, BroadcastMessage, BroadcastStatusType, - BroadcastEvent + BroadcastEvent, + BroadcastProviderMessage ) @@ -1050,3 +1051,18 @@ def create_broadcast_event( db.session.add(b_e) db.session.commit() return b_e + + +def create_broadcast_provider_message( + broadcast_event, + provider, + status='sending' +): + provider_message = BroadcastProviderMessage( + broadcast_event=broadcast_event, + provider=provider, + status=status + ) + db.session.add(provider_message) + db.session.commit() + return provider_message