diff --git a/app/authentication/auth.py b/app/authentication/auth.py index 458e43c38..f1bc003b8 100644 --- a/app/authentication/auth.py +++ b/app/authentication/auth.py @@ -6,13 +6,10 @@ from notifications_python_client.errors import ( from notifications_utils import request_helper from sqlalchemy.exc import DataError from sqlalchemy.orm.exc import NoResultFound -from gds_metrics import Histogram from app import db from app.dao.services_dao import dao_fetch_service_by_id -from app.dao.api_key_dao import get_model_api_keys from app.serialised_models import ( - SerialisedAPIKey, SerialisedAPIKeyCollection, SerialisedService, ) @@ -20,11 +17,6 @@ from app.serialised_models import ( GENERAL_TOKEN_ERROR_MESSAGE = 'Invalid token: make sure your API token matches the example at https://docs.notifications.service.gov.uk/rest-api.html#authorisation-header' # noqa -AUTH_DB_CONNECTION_DURATION_SECONDS = Histogram( - 'auth_db_connection_duration_seconds', - 'Time taken to get DB connection and fetch service from database', -) - class AuthError(Exception): def __init__(self, message, code, service_id=None, api_key_id=None): @@ -93,30 +85,6 @@ def requires_admin_auth(): raise AuthError('Unauthorized: admin authentication token required', 401) -def get_service_dict(issuer): - from app.schemas import service_schema - with AUTH_DB_CONNECTION_DURATION_SECONDS.time(): - fetched = dao_fetch_service_by_id(issuer) - return service_schema.dump(fetched).data - - -@SerialisedService.cache -def get_service_model(issuer): - return SerialisedService(get_service_dict(issuer)) - - -def get_api_keys_dict(issuer): - return [ - {k: getattr(key, k) for k in SerialisedAPIKey.ALLOWED_PROPERTIES} - for key in get_model_api_keys(issuer) - ] - - -@SerialisedAPIKeyCollection.cache -def get_api_keys_models(issuer): - return SerialisedAPIKeyCollection(get_api_keys_dict(issuer)) - - def requires_auth(): request_helper.check_proxy_header_before_request() @@ -124,8 +92,8 @@ def requires_auth(): issuer = __get_token_issuer(auth_token) # ie the `iss` claim which should be a service ID try: - service = get_service_model(issuer) - service.api_keys = get_api_keys_models(issuer) + service = SerialisedService.from_id(issuer) + service.api_keys = SerialisedAPIKeyCollection.from_service_id(issuer) db.session.commit() except DataError: raise AuthError("Invalid token: service id is not the right data type", 403) diff --git a/app/notifications/validators.py b/app/notifications/validators.py index f3d6cac62..06615ca6b 100644 --- a/app/notifications/validators.py +++ b/app/notifications/validators.py @@ -8,8 +8,7 @@ from notifications_utils.recipients import ( ) from notifications_utils.clients.redis import rate_limit_cache_key, daily_limit_cache_key -from app import db -from app.dao import services_dao, templates_dao +from app.dao import services_dao from app.dao.service_sms_sender_dao import dao_get_service_sms_senders_by_id from app.models import ( INTERNATIONAL_SMS_TYPE, SMS_TYPE, EMAIL_TYPE, LETTER_TYPE, @@ -149,32 +148,15 @@ def check_notification_content_is_not_empty(template_with_content): raise BadRequestError(message=message) -def get_template_dict(template_id, service_id): - from app.schemas import template_schema +def validate_template(template_id, personalisation, service, notification_type): + try: - fetched_template = templates_dao.dao_get_template_by_id_and_service_id( - template_id=template_id, - service_id=service_id - ) + template = SerialisedTemplate.from_template_id_and_service_id(template_id, service.id) except NoResultFound: message = 'Template not found' raise BadRequestError(message=message, fields=[{'template': message}]) - template_dict = template_schema.dump(fetched_template).data - - db.session.commit() - return template_dict - - -@SerialisedTemplate.cache -def get_template_model(template_id, service_id): - return SerialisedTemplate(get_template_dict(template_id, service_id)) - - -def validate_template(template_id, personalisation, service, notification_type): - template = get_template_model(template_id, service.id) - check_template_is_for_notification_type(notification_type, template.template_type) check_template_is_active(template) diff --git a/app/serialised_models.py b/app/serialised_models.py index 33d75f13d..5b39faa86 100644 --- a/app/serialised_models.py +++ b/app/serialised_models.py @@ -5,23 +5,34 @@ from threading import RLock import cachetools +from gds_metrics import Histogram + +from app import db +from app.dao.services_dao import dao_fetch_service_by_id +from app.dao.api_key_dao import get_model_api_keys + caches = defaultdict(partial(cachetools.TTLCache, maxsize=1024, ttl=2)) locks = defaultdict(RLock) -class CacheMixin: - - @classmethod - def cache(cls, func): - - @cachetools.cached(cache=caches[cls.__name__], lock=locks[cls.__name__]) - def wrapper(*args, **kwargs): - return func(*args, **kwargs) - - return wrapper +AUTH_DB_CONNECTION_DURATION_SECONDS = Histogram( + 'auth_db_connection_duration_seconds', + 'Time taken to get DB connection and fetch service from database', +) -class SerialisedModel(ABC, CacheMixin): +def cache(func): + @cachetools.cached( + cache=caches[func.__qualname__], + lock=locks[func.__qualname__], + ) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + +class SerialisedModel(ABC): """ A SerialisedModel takes a dictionary, typically created by @@ -44,7 +55,7 @@ class SerialisedModel(ABC, CacheMixin): return super().__dir__() + list(sorted(self.ALLOWED_PROPERTIES)) -class SerialisedModelCollection(ABC, CacheMixin): +class SerialisedModelCollection(ABC): """ A SerialisedModelCollection takes a list of dictionaries, typically @@ -68,6 +79,7 @@ class SerialisedModelCollection(ABC, CacheMixin): class SerialisedTemplate(SerialisedModel): + ALLOWED_PROPERTIES = { 'archived', 'content', @@ -80,6 +92,23 @@ class SerialisedTemplate(SerialisedModel): 'version', } + @classmethod + @cache + def from_template_id_and_service_id(cls, template_id, service_id): + + from app.dao.templates_dao import dao_get_template_by_id_and_service_id + from app.schemas import template_schema + + fetched_template = dao_get_template_by_id_and_service_id( + template_id=template_id, + service_id=service_id + ) + + template_dict = template_schema.dump(fetched_template).data + + db.session.commit() + return cls(template_dict) + class SerialisedService(SerialisedModel): ALLOWED_PROPERTIES = { @@ -92,6 +121,14 @@ class SerialisedService(SerialisedModel): 'restricted', } + @classmethod + @cache + def from_id(cls, service_id): + from app.schemas import service_schema + with AUTH_DB_CONNECTION_DURATION_SECONDS.time(): + fetched = dao_fetch_service_by_id(service_id) + return cls(service_schema.dump(fetched).data) + class SerialisedAPIKey(SerialisedModel): ALLOWED_PROPERTIES = { @@ -104,3 +141,11 @@ class SerialisedAPIKey(SerialisedModel): class SerialisedAPIKeyCollection(SerialisedModelCollection): model = SerialisedAPIKey + + @classmethod + @cache + def from_service_id(cls, service_id): + return cls([ + {k: getattr(key, k) for k in SerialisedAPIKey.ALLOWED_PROPERTIES} + for key in get_model_api_keys(service_id) + ])