diff --git a/app/delivery/send_to_providers.py b/app/delivery/send_to_providers.py index 23d7f4f0e..e761956f3 100644 --- a/app/delivery/send_to_providers.py +++ b/app/delivery/send_to_providers.py @@ -171,7 +171,12 @@ def provider_to_use(notification_type, international=False): ) raise Exception("No active {} providers".format(notification_type)) - chosen_provider = random.choices(active_providers, weights=[p.priority for p in active_providers])[0] + if len(active_providers) == 1: + weights = [100] + else: + weights = [p.priority for p in active_providers] + + chosen_provider = random.choices(active_providers, weights=weights)[0] return notification_provider_clients.get_client_by_name_and_type(chosen_provider.identifier, notification_type) diff --git a/tests/app/delivery/test_send_to_providers.py b/tests/app/delivery/test_send_to_providers.py index 16e8a817c..cf1ef1552 100644 --- a/tests/app/delivery/test_send_to_providers.py +++ b/tests/app/delivery/test_send_to_providers.py @@ -72,8 +72,18 @@ def test_provider_to_use_should_cache_repeated_calls(mocker, notify_db_session): assert len(mock_choices.call_args_list) == 1 -def test_provider_to_use_should_only_return_mmg_for_international(mocker, notify_db_session): +@pytest.mark.parametrize('international_provider_priority', ( + # Since there’s only one international provider it should always + # be used, no matter what its priority is set to + 0, 50, 100, +)) +def test_provider_to_use_should_only_return_mmg_for_international( + mocker, + notify_db_session, + international_provider_priority, +): mmg = get_provider_details_by_identifier('mmg') + mmg.priority = international_provider_priority mock_choices = mocker.patch('app.delivery.send_to_providers.random.choices', return_value=[mmg]) ret = send_to_providers.provider_to_use('sms', international=True) @@ -90,7 +100,7 @@ def test_provider_to_use_should_only_return_active_providers(mocker, restore_pro ret = send_to_providers.provider_to_use('sms') - mock_choices.assert_called_once_with([firetext], weights=[0]) + mock_choices.assert_called_once_with([firetext], weights=[100]) assert ret.get_name() == 'firetext'