This commit is contained in:
Chris Hill-Scott
2020-06-18 11:44:36 +01:00
parent f466abeea6
commit 7ff9ee40e8
3 changed files with 63 additions and 68 deletions

View File

@@ -6,13 +6,10 @@ from notifications_python_client.errors import (
from notifications_utils import request_helper
from sqlalchemy.exc import DataError
from sqlalchemy.orm.exc import NoResultFound
from gds_metrics import Histogram
from app import db
from app.dao.services_dao import dao_fetch_service_by_id
from app.dao.api_key_dao import get_model_api_keys
from app.serialised_models import (
SerialisedAPIKey,
SerialisedAPIKeyCollection,
SerialisedService,
)
@@ -20,11 +17,6 @@ from app.serialised_models import (
GENERAL_TOKEN_ERROR_MESSAGE = 'Invalid token: make sure your API token matches the example at https://docs.notifications.service.gov.uk/rest-api.html#authorisation-header' # noqa
AUTH_DB_CONNECTION_DURATION_SECONDS = Histogram(
'auth_db_connection_duration_seconds',
'Time taken to get DB connection and fetch service from database',
)
class AuthError(Exception):
def __init__(self, message, code, service_id=None, api_key_id=None):
@@ -93,30 +85,6 @@ def requires_admin_auth():
raise AuthError('Unauthorized: admin authentication token required', 401)
def get_service_dict(issuer):
from app.schemas import service_schema
with AUTH_DB_CONNECTION_DURATION_SECONDS.time():
fetched = dao_fetch_service_by_id(issuer)
return service_schema.dump(fetched).data
@SerialisedService.cache
def get_service_model(issuer):
return SerialisedService(get_service_dict(issuer))
def get_api_keys_dict(issuer):
return [
{k: getattr(key, k) for k in SerialisedAPIKey.ALLOWED_PROPERTIES}
for key in get_model_api_keys(issuer)
]
@SerialisedAPIKeyCollection.cache
def get_api_keys_models(issuer):
return SerialisedAPIKeyCollection(get_api_keys_dict(issuer))
def requires_auth():
request_helper.check_proxy_header_before_request()
@@ -124,8 +92,8 @@ def requires_auth():
issuer = __get_token_issuer(auth_token) # ie the `iss` claim which should be a service ID
try:
service = get_service_model(issuer)
service.api_keys = get_api_keys_models(issuer)
service = SerialisedService.from_id(issuer)
service.api_keys = SerialisedAPIKeyCollection.from_service_id(issuer)
db.session.commit()
except DataError:
raise AuthError("Invalid token: service id is not the right data type", 403)

View File

@@ -8,8 +8,7 @@ from notifications_utils.recipients import (
)
from notifications_utils.clients.redis import rate_limit_cache_key, daily_limit_cache_key
from app import db
from app.dao import services_dao, templates_dao
from app.dao import services_dao
from app.dao.service_sms_sender_dao import dao_get_service_sms_senders_by_id
from app.models import (
INTERNATIONAL_SMS_TYPE, SMS_TYPE, EMAIL_TYPE, LETTER_TYPE,
@@ -149,32 +148,15 @@ def check_notification_content_is_not_empty(template_with_content):
raise BadRequestError(message=message)
def get_template_dict(template_id, service_id):
from app.schemas import template_schema
def validate_template(template_id, personalisation, service, notification_type):
try:
fetched_template = templates_dao.dao_get_template_by_id_and_service_id(
template_id=template_id,
service_id=service_id
)
template = SerialisedTemplate.from_template_id_and_service_id(template_id, service.id)
except NoResultFound:
message = 'Template not found'
raise BadRequestError(message=message,
fields=[{'template': message}])
template_dict = template_schema.dump(fetched_template).data
db.session.commit()
return template_dict
@SerialisedTemplate.cache
def get_template_model(template_id, service_id):
return SerialisedTemplate(get_template_dict(template_id, service_id))
def validate_template(template_id, personalisation, service, notification_type):
template = get_template_model(template_id, service.id)
check_template_is_for_notification_type(notification_type, template.template_type)
check_template_is_active(template)

View File

@@ -5,23 +5,34 @@ from threading import RLock
import cachetools
from gds_metrics import Histogram
from app import db
from app.dao.services_dao import dao_fetch_service_by_id
from app.dao.api_key_dao import get_model_api_keys
caches = defaultdict(partial(cachetools.TTLCache, maxsize=1024, ttl=2))
locks = defaultdict(RLock)
class CacheMixin:
@classmethod
def cache(cls, func):
@cachetools.cached(cache=caches[cls.__name__], lock=locks[cls.__name__])
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
AUTH_DB_CONNECTION_DURATION_SECONDS = Histogram(
'auth_db_connection_duration_seconds',
'Time taken to get DB connection and fetch service from database',
)
class SerialisedModel(ABC, CacheMixin):
def cache(func):
@cachetools.cached(
cache=caches[func.__qualname__],
lock=locks[func.__qualname__],
)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
class SerialisedModel(ABC):
"""
A SerialisedModel takes a dictionary, typically created by
@@ -44,7 +55,7 @@ class SerialisedModel(ABC, CacheMixin):
return super().__dir__() + list(sorted(self.ALLOWED_PROPERTIES))
class SerialisedModelCollection(ABC, CacheMixin):
class SerialisedModelCollection(ABC):
"""
A SerialisedModelCollection takes a list of dictionaries, typically
@@ -68,6 +79,7 @@ class SerialisedModelCollection(ABC, CacheMixin):
class SerialisedTemplate(SerialisedModel):
ALLOWED_PROPERTIES = {
'archived',
'content',
@@ -80,6 +92,23 @@ class SerialisedTemplate(SerialisedModel):
'version',
}
@classmethod
@cache
def from_template_id_and_service_id(cls, template_id, service_id):
from app.dao.templates_dao import dao_get_template_by_id_and_service_id
from app.schemas import template_schema
fetched_template = dao_get_template_by_id_and_service_id(
template_id=template_id,
service_id=service_id
)
template_dict = template_schema.dump(fetched_template).data
db.session.commit()
return cls(template_dict)
class SerialisedService(SerialisedModel):
ALLOWED_PROPERTIES = {
@@ -92,6 +121,14 @@ class SerialisedService(SerialisedModel):
'restricted',
}
@classmethod
@cache
def from_id(cls, service_id):
from app.schemas import service_schema
with AUTH_DB_CONNECTION_DURATION_SECONDS.time():
fetched = dao_fetch_service_by_id(service_id)
return cls(service_schema.dump(fetched).data)
class SerialisedAPIKey(SerialisedModel):
ALLOWED_PROPERTIES = {
@@ -104,3 +141,11 @@ class SerialisedAPIKey(SerialisedModel):
class SerialisedAPIKeyCollection(SerialisedModelCollection):
model = SerialisedAPIKey
@classmethod
@cache
def from_service_id(cls, service_id):
return cls([
{k: getattr(key, k) for k in SerialisedAPIKey.ALLOWED_PROPERTIES}
for key in get_model_api_keys(service_id)
])