diff --git a/app/celery/broadcast_message_tasks.py b/app/celery/broadcast_message_tasks.py index 652644636..bd5d38b47 100644 --- a/app/celery/broadcast_message_tasks.py +++ b/app/celery/broadcast_message_tasks.py @@ -48,6 +48,10 @@ def send_broadcast_provider_message(broadcast_event_id, provider): for polygon in broadcast_event.transmitted_areas["simple_polygons"] ] + channel = "test" + if broadcast_event.service.broadcast_channel: + channel = broadcast_event.service.broadcast_channel + cbc_proxy_provider_client = cbc_proxy_client.get_proxy(provider) if broadcast_event.message_type == BroadcastEventMessageType.ALERT: @@ -59,6 +63,7 @@ def send_broadcast_provider_message(broadcast_event_id, provider): areas=areas, sent=broadcast_event.sent_at_as_cap_datetime_string, expires=broadcast_event.transmitted_finishes_at_as_cap_datetime_string, + channel=channel ) elif broadcast_event.message_type == BroadcastEventMessageType.UPDATE: cbc_proxy_provider_client.update_and_send_broadcast( @@ -70,6 +75,13 @@ def send_broadcast_provider_message(broadcast_event_id, provider): previous_provider_messages=broadcast_event.get_earlier_provider_messages(provider), sent=broadcast_event.sent_at_as_cap_datetime_string, expires=broadcast_event.transmitted_finishes_at_as_cap_datetime_string, + # We think an alert update should always go out on the same channel that created the alert + # We recognise there is a small risk with this code here that if the services channel was + # changed between an alert being sent out and then updated, then something might go wrong + # but we are relying on service channels changing almost never, and not mid incident + # We may consider in the future, changing this such that we store the channel a broadcast was + # sent on on the broadcast message itself and pick the value from there instead of the service + channel=channel ) elif broadcast_event.message_type == BroadcastEventMessageType.CANCEL: cbc_proxy_provider_client.cancel_broadcast( diff --git a/app/clients/cbc_proxy.py b/app/clients/cbc_proxy.py index ea7a9539a..2686d07af 100644 --- a/app/clients/cbc_proxy.py +++ b/app/clients/cbc_proxy.py @@ -90,7 +90,7 @@ class CBCProxyClientBase(ABC): pass def create_and_send_broadcast( - self, identifier, headline, description, areas, sent, expires, message_number=None + self, identifier, headline, description, areas, sent, expires, channel, message_number=None ): pass @@ -98,11 +98,10 @@ class CBCProxyClientBase(ABC): def update_and_send_broadcast( self, identifier, previous_provider_messages, headline, description, areas, - sent, expires, message_number=None + sent, expires, channel, message_number=None ): pass - # We have not implemented cancelling a broadcast def cancel_broadcast( self, identifier, previous_provider_messages, headline, description, areas, @@ -198,7 +197,7 @@ class CBCProxyEE(CBCProxyClientBase): self._invoke_lambda_with_failover(payload=payload) def create_and_send_broadcast( - self, identifier, headline, description, areas, sent, expires, message_number=None + self, identifier, headline, description, areas, sent, expires, channel, message_number=None ): payload = { 'message_type': 'alert', @@ -210,7 +209,7 @@ class CBCProxyEE(CBCProxyClientBase): 'sent': sent, 'expires': expires, 'language': self.infer_language_from(description), - 'channel': 'test', + 'channel': channel, } self._invoke_lambda_with_failover(payload=payload) @@ -259,7 +258,7 @@ class CBCProxyThree(CBCProxyClientBase): self._invoke_lambda_with_failover(payload=payload) def create_and_send_broadcast( - self, identifier, headline, description, areas, sent, expires, message_number=None + self, identifier, headline, description, areas, sent, expires, channel, message_number=None ): payload = { 'message_type': 'alert', @@ -271,7 +270,7 @@ class CBCProxyThree(CBCProxyClientBase): 'sent': sent, 'expires': expires, 'language': self.infer_language_from(description), - 'channel': 'test', + 'channel': channel, } self._invoke_lambda_with_failover(payload=payload) @@ -319,7 +318,7 @@ class CBCProxyO2(CBCProxyClientBase): self._invoke_lambda_with_failover(payload=payload) def create_and_send_broadcast( - self, identifier, headline, description, areas, sent, expires, message_number=None + self, identifier, headline, description, areas, sent, expires, channel, message_number=None ): payload = { 'message_type': 'alert', @@ -331,7 +330,7 @@ class CBCProxyO2(CBCProxyClientBase): 'sent': sent, 'expires': expires, 'language': self.infer_language_from(description), - 'channel': 'test', + 'channel': channel, } self._invoke_lambda_with_failover(payload=payload) @@ -381,7 +380,7 @@ class CBCProxyVodafone(CBCProxyClientBase): self._invoke_lambda_with_failover(payload=payload) def create_and_send_broadcast( - self, identifier, message_number, headline, description, areas, sent, expires, + self, identifier, message_number, headline, description, areas, sent, expires, channel ): payload = { 'message_type': 'alert', @@ -394,7 +393,7 @@ class CBCProxyVodafone(CBCProxyClientBase): 'sent': sent, 'expires': expires, 'language': self.infer_language_from(description), - 'channel': 'test', + 'channel': channel, } self._invoke_lambda_with_failover(payload=payload) diff --git a/app/models.py b/app/models.py index 91e2f35b9..64b6e4575 100644 --- a/app/models.py +++ b/app/models.py @@ -505,6 +505,7 @@ class Service(db.Model, Versioned): backref=db.backref('services', lazy='dynamic')) allowed_broadcast_provider = association_proxy('service_broadcast_provider_restriction', 'provider') + broadcast_channel = association_proxy('service_broadcast_settings', 'channel') @classmethod def from_json(cls, data): @@ -2519,6 +2520,39 @@ class BroadcastProviderMessageNumber(db.Model): ) +class ServiceBroadcastSettings(db.Model): + """ + For the moment, broadcasts services CAN have a row in this table which will configure which broadcast + channel they will send to. If they don't then we will assume they should send to the test channel. + + There should only be one row per service in this table, and this is enforced by + the service_id being a primary key. + + TODO: We should enforce that every broadcast service will have a row in this table. We will need to do + this when the admin turns a service into a broadcast service, it inserts a row into this table and adds + the service permission for broadcasts for the service. Once that is up and running, we then should write + a DB migration to create rows for all broadcast services that do not have one yet in this table. + + TODO: Move functionality on the ServiceBroadcastProviderRestriction into this table and remove the + ServiceBroadcastProviderRestriction table + """ + __tablename__ = "service_broadcast_settings" + + service_id = db.Column(UUID(as_uuid=True), db.ForeignKey('services.id'), primary_key=True, nullable=False) + service = db.relationship(Service, backref=db.backref("service_broadcast_settings", uselist=False)) + channel = db.Column( + db.String(255), db.ForeignKey('broadcast_channel_types.name'), nullable=False + ) + created_at = db.Column(db.DateTime, nullable=False, default=datetime.datetime.utcnow) + updated_at = db.Column(db.DateTime, nullable=True, onupdate=datetime.datetime.utcnow) + + +class BroadcastChannelTypes(db.Model): + __tablename__ = 'broadcast_channel_types' + + name = db.Column(db.String(255), primary_key=True) + + class ServiceBroadcastProviderRestriction(db.Model): """ Most services don't send broadcasts. Of those that do, most send to all broadcast providers. diff --git a/app/schemas.py b/app/schemas.py index 4451471b8..7d9a59d06 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -271,6 +271,7 @@ class ServiceSchema(BaseSchema, UUIDsAsStringsMixin): 'reply_to_email_addresses', 'returned_letters', 'service_broadcast_provider_restriction', + 'service_broadcast_settings', 'service_notification_stats', 'service_provider_stats', 'service_sms_senders', diff --git a/migrations/versions/0342_service_broadcast_settings.py b/migrations/versions/0342_service_broadcast_settings.py new file mode 100644 index 000000000..ba706f562 --- /dev/null +++ b/migrations/versions/0342_service_broadcast_settings.py @@ -0,0 +1,43 @@ +""" + +Revision ID: 0342_service_broadcast_settings +Revises: 0341_new_letter_rates +Create Date: 2021-01-28 21:30:23.102340 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +revision = '0342_service_broadcast_settings' +down_revision = '0341_new_letter_rates' + +CHANNEL_TYPES = ["test", "severe"] + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('broadcast_channel_types', + sa.Column('name', sa.String(length=255), nullable=False), + sa.PrimaryKeyConstraint('name') + ) + op.create_table('service_broadcast_settings', + sa.Column('service_id', postgresql.UUID(as_uuid=True), nullable=False), + sa.Column('channel', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['channel'], ['broadcast_channel_types.name'], ), + sa.ForeignKeyConstraint(['service_id'], ['services.id'], ), + sa.PrimaryKeyConstraint('service_id') + ) + # ### end Alembic commands ### + + for channel in CHANNEL_TYPES: + op.execute(f"INSERT INTO broadcast_channel_types VALUES ('{channel}')") + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('service_broadcast_settings') + op.drop_table('broadcast_channel_types') + # ### end Alembic commands ### diff --git a/tests/app/celery/test_broadcast_message_tasks.py b/tests/app/celery/test_broadcast_message_tasks.py index 16adcc5a6..ade0d238e 100644 --- a/tests/app/celery/test_broadcast_message_tasks.py +++ b/tests/app/celery/test_broadcast_message_tasks.py @@ -9,7 +9,8 @@ from app.models import ( BroadcastStatusType, BroadcastEventMessageType, BroadcastProviderMessageStatus, - ServiceBroadcastProviderRestriction + ServiceBroadcastProviderRestriction, + ServiceBroadcastSettings, ) from app.celery.broadcast_message_tasks import send_broadcast_event, send_broadcast_provider_message, trigger_link_test @@ -107,6 +108,8 @@ def test_send_broadcast_event_does_nothing_if_cbc_proxy_disabled(mocker, notify_ @freeze_time('2020-08-01 12:00') @pytest.mark.parametrize('provider,provider_capitalised', [ ['ee', 'EE'], + ['three', 'Three'], + ['o2', 'O2'], ['vodafone', 'Vodafone'], ]) def test_send_broadcast_provider_message_sends_data_correctly( @@ -153,12 +156,104 @@ def test_send_broadcast_provider_message_sends_data_correctly( }], sent=event.sent_at_as_cap_datetime_string, expires=event.transmitted_finishes_at_as_cap_datetime_string, + channel="test", ) @freeze_time('2020-08-01 12:00') +@pytest.mark.parametrize('provider,provider_capitalised', [ + ['ee', 'EE'], + ['three', 'Three'], + ['o2', 'O2'], + ['vodafone', 'Vodafone'], +]) +@pytest.mark.parametrize('channel', ['test', 'severe']) +def test_send_broadcast_provider_message_uses_channel_set_on_broadcast_service( + notify_db, mocker, sample_service, provider, provider_capitalised, channel +): + template = create_template(sample_service, BROADCAST_TYPE) + broadcast_message = create_broadcast_message( + template, + areas={ + 'areas': ['london', 'glasgow'], + 'simple_polygons': [ + [[50.12, 1.2], [50.13, 1.2], [50.14, 1.21]], + [[-4.53, 55.72], [-3.88, 55.72], [-3.88, 55.96], [-4.53, 55.96]], + ], + }, + status=BroadcastStatusType.BROADCASTING + ) + event = create_broadcast_event(broadcast_message) + notify_db.session.add(ServiceBroadcastSettings(service=sample_service, channel=channel)) + + mock_create_broadcast = mocker.patch( + f'app.clients.cbc_proxy.CBCProxy{provider_capitalised}.create_and_send_broadcast', + ) + + send_broadcast_provider_message(provider=provider, broadcast_event_id=str(event.id)) + + mock_create_broadcast.assert_called_once_with( + identifier=mocker.ANY, + message_number=mocker.ANY, + headline='GOV.UK Notify Broadcast', + description='this is an emergency broadcast message', + areas=mocker.ANY, + sent=mocker.ANY, + expires=mocker.ANY, + channel=channel, + ) + + +@freeze_time('2020-08-01 12:00') +@pytest.mark.parametrize('provider,provider_capitalised', [ + ['ee', 'EE'], + ['three', 'Three'], + ['o2', 'O2'], + ['vodafone', 'Vodafone'], +]) +def test_send_broadcast_provider_message_defaults_to_test_channel_if_no_service_broadcast_settings( + notify_db, mocker, sample_service, provider, provider_capitalised +): + template = create_template(sample_service, BROADCAST_TYPE) + broadcast_message = create_broadcast_message( + template, + areas={ + 'areas': ['london', 'glasgow'], + 'simple_polygons': [ + [[50.12, 1.2], [50.13, 1.2], [50.14, 1.21]], + [[-4.53, 55.72], [-3.88, 55.72], [-3.88, 55.96], [-4.53, 55.96]], + ], + }, + status=BroadcastStatusType.BROADCASTING + ) + event = create_broadcast_event(broadcast_message) + mock_create_broadcast = mocker.patch( + f'app.clients.cbc_proxy.CBCProxy{provider_capitalised}.create_and_send_broadcast', + ) + + send_broadcast_provider_message(provider=provider, broadcast_event_id=str(event.id)) + + mock_create_broadcast.assert_called_once_with( + identifier=mocker.ANY, + message_number=mocker.ANY, + headline='GOV.UK Notify Broadcast', + description='this is an emergency broadcast message', + areas=mocker.ANY, + sent=mocker.ANY, + expires=mocker.ANY, + channel="test", + ) + + +@freeze_time('2020-08-01 12:00') +@pytest.mark.parametrize('provider,provider_capitalised', [ + ['ee', 'EE'], + ['three', 'Three'], + ['o2', 'O2'], + ['vodafone', 'Vodafone'], +]) def test_send_broadcast_provider_message_sends_data_correctly_when_broadcast_message_has_no_template( - mocker, sample_service, + mocker, sample_service, provider, provider_capitalised ): broadcast_message = create_broadcast_message( service=sample_service, @@ -176,12 +271,12 @@ def test_send_broadcast_provider_message_sends_data_correctly_when_broadcast_mes event = create_broadcast_event(broadcast_message) mock_create_broadcast = mocker.patch( - f'app.clients.cbc_proxy.CBCProxyEE.create_and_send_broadcast', + f'app.clients.cbc_proxy.CBCProxy{provider_capitalised}.create_and_send_broadcast', ) - send_broadcast_provider_message(provider='ee', broadcast_event_id=str(event.id)) + send_broadcast_provider_message(provider=provider, broadcast_event_id=str(event.id)) - broadcast_provider_message = event.get_provider_message('ee') + broadcast_provider_message = event.get_provider_message(provider) mock_create_broadcast.assert_called_once_with( identifier=str(broadcast_provider_message.id), @@ -191,11 +286,14 @@ def test_send_broadcast_provider_message_sends_data_correctly_when_broadcast_mes areas=mocker.ANY, sent=mocker.ANY, expires=mocker.ANY, + channel="test" ) @pytest.mark.parametrize('provider,provider_capitalised', [ ['ee', 'EE'], + ['three', 'Three'], + ['o2', 'O2'], ['vodafone', 'Vodafone'], ]) def test_send_broadcast_provider_message_sends_update_with_references( @@ -240,11 +338,14 @@ def test_send_broadcast_provider_message_sends_update_with_references( ], sent=update_event.sent_at_as_cap_datetime_string, expires=update_event.transmitted_finishes_at_as_cap_datetime_string, + channel="test" ) @pytest.mark.parametrize('provider,provider_capitalised', [ ['ee', 'EE'], + ['three', 'Three'], + ['o2', 'O2'], ['vodafone', 'Vodafone'], ]) def test_send_broadcast_provider_message_sends_cancel_with_references( @@ -290,7 +391,13 @@ def test_send_broadcast_provider_message_sends_cancel_with_references( ) -def test_send_broadcast_provider_message_errors(mocker, sample_service): +@pytest.mark.parametrize("provider,provider_capitalised", [ + ['ee', 'EE'], + ['three', 'Three'], + ['o2', 'O2'], + ['vodafone', 'Vodafone'], +]) +def test_send_broadcast_provider_message_errors(mocker, sample_service, provider, provider_capitalised): template = create_template(sample_service, BROADCAST_TYPE) broadcast_message = create_broadcast_message( @@ -307,12 +414,12 @@ def test_send_broadcast_provider_message_errors(mocker, sample_service): event = create_broadcast_event(broadcast_message) mock_create_broadcast = mocker.patch( - 'app.clients.cbc_proxy.CBCProxyEE.create_and_send_broadcast', + f'app.clients.cbc_proxy.CBCProxy{provider_capitalised}.create_and_send_broadcast', side_effect=Exception('oh no'), ) with pytest.raises(Exception) as ex: - send_broadcast_provider_message(provider='ee', broadcast_event_id=str(event.id)) + send_broadcast_provider_message(provider=provider, broadcast_event_id=str(event.id)) assert ex.match('oh no') @@ -330,12 +437,15 @@ def test_send_broadcast_provider_message_errors(mocker, sample_service): }], sent=event.sent_at_as_cap_datetime_string, expires=event.transmitted_finishes_at_as_cap_datetime_string, + channel="test" ) @pytest.mark.parametrize("provider,provider_capitalised", [ - ["ee", "EE"], - ["vodafone", "Vodafone"] + ['ee', 'EE'], + ['three', 'Three'], + ['o2', 'O2'], + ['vodafone', 'Vodafone'], ]) def test_trigger_link_tests_invokes_cbc_proxy_client( mocker, provider, provider_capitalised diff --git a/tests/app/clients/test_cbc_proxy.py b/tests/app/clients/test_cbc_proxy.py index 59ef33a8e..f6618d606 100644 --- a/tests/app/clients/test_cbc_proxy.py +++ b/tests/app/clients/test_cbc_proxy.py @@ -110,7 +110,9 @@ def test_cbc_proxy_one_2_many_create_and_send_invokes_function( headline=headline, description=description, areas=EXAMPLE_AREAS, - sent=sent, expires=expires, + sent=sent, + expires=expires, + channel="severe", ) ld_client_mock.invoke.assert_called_once_with( @@ -133,7 +135,7 @@ def test_cbc_proxy_one_2_many_create_and_send_invokes_function( assert payload['sent'] == sent assert payload['expires'] == expires assert payload['language'] == expected_language - assert payload['channel'] == 'test' + assert payload['channel'] == 'severe' @pytest.mark.parametrize('cbc', ['ee', 'three', 'o2']) @@ -227,7 +229,9 @@ def test_cbc_proxy_vodafone_create_and_send_invokes_function( headline=headline, description=description, areas=EXAMPLE_AREAS, - sent=sent, expires=expires, + sent=sent, + expires=expires, + channel="test", ) ld_client_mock.invoke.assert_called_once_with( @@ -348,6 +352,7 @@ def test_cbc_proxy_will_failover_to_second_lambda_if_function_error( areas=EXAMPLE_AREAS, sent='a-passed-through-sent-value', expires='a-passed-through-expires-value', + channel="severe", ) assert ld_client_mock.invoke.call_args_list == [ @@ -395,6 +400,7 @@ def test_cbc_proxy_will_failover_to_second_lambda_if_invoke_error( areas=EXAMPLE_AREAS, sent='a-passed-through-sent-value', expires='a-passed-through-expires-value', + channel="test", ) assert ld_client_mock.invoke.call_args_list == [ @@ -436,6 +442,7 @@ def test_cbc_proxy_create_and_send_tries_failover_lambda_on_invoke_error_and_rai areas=EXAMPLE_AREAS, sent='a-passed-through-sent-value', expires='a-passed-through-expires-value', + channel="test", ) assert e.match(f'Lambda failed for both {cbc}-1-proxy and {cbc}-2-proxy') @@ -484,6 +491,7 @@ def test_cbc_proxy_create_and_send_tries_failover_lambda_on_function_error_and_r areas=EXAMPLE_AREAS, sent='a-passed-through-sent-value', expires='a-passed-through-expires-value', + channel="severe", ) assert e.match(f'Lambda failed for both {cbc}-1-proxy and {cbc}-2-proxy') diff --git a/tests/app/service/test_rest.py b/tests/app/service/test_rest.py index 6aac914ea..937518a2a 100644 --- a/tests/app/service/test_rest.py +++ b/tests/app/service/test_rest.py @@ -36,7 +36,6 @@ from app.models import ( INBOUND_SMS_TYPE, NOTIFICATION_RETURNED_LETTER, UPLOAD_LETTERS, - ) from tests import create_authorization_header from tests.app.db import ( diff --git a/tests/conftest.py b/tests/conftest.py index 420f40afd..fa7e3d17d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -119,7 +119,8 @@ def notify_db_session(notify_db, sms_providers): "auth_type", "broadcast_status_type", "invite_status_type", - "service_callback_type"]: + "service_callback_type", + "broadcast_channel_types"]: notify_db.engine.execute(tbl.delete()) notify_db.session.commit()