diff --git a/app/authentication/auth.py b/app/authentication/auth.py index e61990749..458e43c38 100644 --- a/app/authentication/auth.py +++ b/app/authentication/auth.py @@ -100,6 +100,7 @@ def get_service_dict(issuer): return service_schema.dump(fetched).data +@SerialisedService.cache def get_service_model(issuer): return SerialisedService(get_service_dict(issuer)) @@ -111,6 +112,7 @@ def get_api_keys_dict(issuer): ] +@SerialisedAPIKeyCollection.cache def get_api_keys_models(issuer): return SerialisedAPIKeyCollection(get_api_keys_dict(issuer)) diff --git a/app/notifications/validators.py b/app/notifications/validators.py index ea939a59e..f3d6cac62 100644 --- a/app/notifications/validators.py +++ b/app/notifications/validators.py @@ -167,6 +167,7 @@ def get_template_dict(template_id, service_id): return template_dict +@SerialisedTemplate.cache def get_template_model(template_id, service_id): return SerialisedTemplate(get_template_dict(template_id, service_id)) diff --git a/app/serialised_models.py b/app/serialised_models.py index 37cf4e363..33d75f13d 100644 --- a/app/serialised_models.py +++ b/app/serialised_models.py @@ -1,7 +1,27 @@ from abc import ABC, abstractmethod +from collections import defaultdict +from functools import partial +from threading import RLock + +import cachetools + +caches = defaultdict(partial(cachetools.TTLCache, maxsize=1024, ttl=2)) +locks = defaultdict(RLock) -class SerialisedModel(ABC): +class CacheMixin: + + @classmethod + def cache(cls, func): + + @cachetools.cached(cache=caches[cls.__name__], lock=locks[cls.__name__]) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + +class SerialisedModel(ABC, CacheMixin): """ A SerialisedModel takes a dictionary, typically created by @@ -24,7 +44,7 @@ class SerialisedModel(ABC): return super().__dir__() + list(sorted(self.ALLOWED_PROPERTIES)) -class SerialisedModelCollection(ABC): +class SerialisedModelCollection(ABC, CacheMixin): """ A SerialisedModelCollection takes a list of dictionaries, typically diff --git a/tests/app/authentication/test_authentication.py b/tests/app/authentication/test_authentication.py index 4cae2acd1..14ed2dde3 100644 --- a/tests/app/authentication/test_authentication.py +++ b/tests/app/authentication/test_authentication.py @@ -8,11 +8,13 @@ import pytest from flask import json, current_app, request from freezegun import freeze_time from notifications_python_client.authentication import create_jwt_token +from unittest.mock import call from app import api_user from app.dao.api_key_dao import get_unsigned_secrets, save_model_api_key, get_unsigned_secret, expire_api_key from app.models import ApiKey, KEY_TYPE_NORMAL from app.authentication.auth import AuthError, requires_admin_auth, requires_auth, GENERAL_TOKEN_ERROR_MESSAGE +from app.serialised_models import caches, get_model_api_keys, dao_fetch_service_by_id from tests.conftest import set_config @@ -457,3 +459,34 @@ def test_proxy_key_on_admin_auth_endpoint(notify_api, check_proxy_header, header ] ) assert response.status_code == expected_status + + +def test_should_cache_service_and_api_key_lookups(mocker, client, sample_api_key): + + mock_get_api_keys = mocker.patch( + 'app.serialised_models.get_model_api_keys', + wraps=get_model_api_keys, + ) + mock_get_service = mocker.patch( + 'app.serialised_models.dao_fetch_service_by_id', + wraps=dao_fetch_service_by_id, + ) + + for i in range(5): + token = __create_token(sample_api_key.service_id) + client.get('/notifications', headers={ + 'Authorization': f'Bearer {token}' + }) + + assert mock_get_api_keys.call_args_list == [ + call(str(sample_api_key.service_id)) + ] + assert mock_get_service.call_args_list == [ + call(str(sample_api_key.service_id)) + ] + + assert caches['SerialisedService.from_id'].currsize == 1 + assert caches['SerialisedService.from_id'].ttl == 2 + + assert caches['SerialisedAPIKeyCollection.from_service_id'].currsize == 1 + assert caches['SerialisedAPIKeyCollection.from_service_id'].ttl == 2