diff --git a/app/delivery/send_to_providers.py b/app/delivery/send_to_providers.py index b22855e08..71a8c8d02 100644 --- a/app/delivery/send_to_providers.py +++ b/app/delivery/send_to_providers.py @@ -1,7 +1,7 @@ import random from urllib import parse from datetime import datetime, timedelta - +from cachetools import TTLCache, cached from flask import current_app from notifications_utils.recipients import ( validate_and_format_phone_number, @@ -148,6 +148,10 @@ def update_notification_to_sending(notification, provider): dao_update_notification(notification) +provider_cache = TTLCache(maxsize=8, ttl=10) + + +@cached(cache=provider_cache) def provider_to_use(notification_type, international=False): active_providers = [ p for p in get_provider_details_by_notification_type(notification_type, international) if p.active diff --git a/tests/app/delivery/test_send_to_providers.py b/tests/app/delivery/test_send_to_providers.py index 6cf1013a4..4ddfa0563 100644 --- a/tests/app/delivery/test_send_to_providers.py +++ b/tests/app/delivery/test_send_to_providers.py @@ -34,6 +34,12 @@ from tests.app.db import ( ) +def setup_function(_function): + # pytest will run this function before each test. It makes sure the + # state of the cache is not shared between tests. + send_to_providers.provider_cache.clear() + + def test_provider_to_use_should_return_random_provider(mocker, notify_db_session): mmg = get_provider_details_by_identifier('mmg') firetext = get_provider_details_by_identifier('firetext') @@ -47,6 +53,21 @@ def test_provider_to_use_should_return_random_provider(mocker, notify_db_session assert ret.get_name() == 'mmg' +def test_provider_to_use_should_cache_repeated_calls(mocker, notify_db_session): + mock_choices = mocker.patch( + 'app.delivery.send_to_providers.random.choices', + wraps=send_to_providers.random.choices, + ) + + results = [ + send_to_providers.provider_to_use('sms', international=False) + for _ in range(10) + ] + + assert all(result == results[0] for result in results) + assert len(mock_choices.call_args_list) == 1 + + def test_provider_to_use_should_only_return_mmg_for_international(mocker, notify_db_session): mmg = get_provider_details_by_identifier('mmg') mock_choices = mocker.patch('app.delivery.send_to_providers.random.choices', return_value=[mmg])