mirror of
https://github.com/GSA/notifications-api.git
synced 2026-02-28 22:09:44 -05:00
Refactor
This commit is contained in:
@@ -6,13 +6,10 @@ from notifications_python_client.errors import (
|
||||
from notifications_utils import request_helper
|
||||
from sqlalchemy.exc import DataError
|
||||
from sqlalchemy.orm.exc import NoResultFound
|
||||
from gds_metrics import Histogram
|
||||
|
||||
from app import db
|
||||
from app.dao.services_dao import dao_fetch_service_by_id
|
||||
from app.dao.api_key_dao import get_model_api_keys
|
||||
from app.serialised_models import (
|
||||
SerialisedAPIKey,
|
||||
SerialisedAPIKeyCollection,
|
||||
SerialisedService,
|
||||
)
|
||||
@@ -20,11 +17,6 @@ from app.serialised_models import (
|
||||
|
||||
GENERAL_TOKEN_ERROR_MESSAGE = 'Invalid token: make sure your API token matches the example at https://docs.notifications.service.gov.uk/rest-api.html#authorisation-header' # noqa
|
||||
|
||||
AUTH_DB_CONNECTION_DURATION_SECONDS = Histogram(
|
||||
'auth_db_connection_duration_seconds',
|
||||
'Time taken to get DB connection and fetch service from database',
|
||||
)
|
||||
|
||||
|
||||
class AuthError(Exception):
|
||||
def __init__(self, message, code, service_id=None, api_key_id=None):
|
||||
@@ -93,30 +85,6 @@ def requires_admin_auth():
|
||||
raise AuthError('Unauthorized: admin authentication token required', 401)
|
||||
|
||||
|
||||
def get_service_dict(issuer):
|
||||
from app.schemas import service_schema
|
||||
with AUTH_DB_CONNECTION_DURATION_SECONDS.time():
|
||||
fetched = dao_fetch_service_by_id(issuer)
|
||||
return service_schema.dump(fetched).data
|
||||
|
||||
|
||||
@SerialisedService.cache
|
||||
def get_service_model(issuer):
|
||||
return SerialisedService(get_service_dict(issuer))
|
||||
|
||||
|
||||
def get_api_keys_dict(issuer):
|
||||
return [
|
||||
{k: getattr(key, k) for k in SerialisedAPIKey.ALLOWED_PROPERTIES}
|
||||
for key in get_model_api_keys(issuer)
|
||||
]
|
||||
|
||||
|
||||
@SerialisedAPIKeyCollection.cache
|
||||
def get_api_keys_models(issuer):
|
||||
return SerialisedAPIKeyCollection(get_api_keys_dict(issuer))
|
||||
|
||||
|
||||
def requires_auth():
|
||||
request_helper.check_proxy_header_before_request()
|
||||
|
||||
@@ -124,8 +92,8 @@ def requires_auth():
|
||||
issuer = __get_token_issuer(auth_token) # ie the `iss` claim which should be a service ID
|
||||
|
||||
try:
|
||||
service = get_service_model(issuer)
|
||||
service.api_keys = get_api_keys_models(issuer)
|
||||
service = SerialisedService.from_id(issuer)
|
||||
service.api_keys = SerialisedAPIKeyCollection.from_service_id(issuer)
|
||||
db.session.commit()
|
||||
except DataError:
|
||||
raise AuthError("Invalid token: service id is not the right data type", 403)
|
||||
|
||||
@@ -8,8 +8,7 @@ from notifications_utils.recipients import (
|
||||
)
|
||||
from notifications_utils.clients.redis import rate_limit_cache_key, daily_limit_cache_key
|
||||
|
||||
from app import db
|
||||
from app.dao import services_dao, templates_dao
|
||||
from app.dao import services_dao
|
||||
from app.dao.service_sms_sender_dao import dao_get_service_sms_senders_by_id
|
||||
from app.models import (
|
||||
INTERNATIONAL_SMS_TYPE, SMS_TYPE, EMAIL_TYPE, LETTER_TYPE,
|
||||
@@ -149,32 +148,15 @@ def check_notification_content_is_not_empty(template_with_content):
|
||||
raise BadRequestError(message=message)
|
||||
|
||||
|
||||
def get_template_dict(template_id, service_id):
|
||||
from app.schemas import template_schema
|
||||
def validate_template(template_id, personalisation, service, notification_type):
|
||||
|
||||
try:
|
||||
fetched_template = templates_dao.dao_get_template_by_id_and_service_id(
|
||||
template_id=template_id,
|
||||
service_id=service_id
|
||||
)
|
||||
template = SerialisedTemplate.from_template_id_and_service_id(template_id, service.id)
|
||||
except NoResultFound:
|
||||
message = 'Template not found'
|
||||
raise BadRequestError(message=message,
|
||||
fields=[{'template': message}])
|
||||
|
||||
template_dict = template_schema.dump(fetched_template).data
|
||||
|
||||
db.session.commit()
|
||||
return template_dict
|
||||
|
||||
|
||||
@SerialisedTemplate.cache
|
||||
def get_template_model(template_id, service_id):
|
||||
return SerialisedTemplate(get_template_dict(template_id, service_id))
|
||||
|
||||
|
||||
def validate_template(template_id, personalisation, service, notification_type):
|
||||
template = get_template_model(template_id, service.id)
|
||||
|
||||
check_template_is_for_notification_type(notification_type, template.template_type)
|
||||
check_template_is_active(template)
|
||||
|
||||
|
||||
@@ -5,23 +5,34 @@ from threading import RLock
|
||||
|
||||
import cachetools
|
||||
|
||||
from gds_metrics import Histogram
|
||||
|
||||
from app import db
|
||||
from app.dao.services_dao import dao_fetch_service_by_id
|
||||
from app.dao.api_key_dao import get_model_api_keys
|
||||
|
||||
caches = defaultdict(partial(cachetools.TTLCache, maxsize=1024, ttl=2))
|
||||
locks = defaultdict(RLock)
|
||||
|
||||
|
||||
class CacheMixin:
|
||||
|
||||
@classmethod
|
||||
def cache(cls, func):
|
||||
|
||||
@cachetools.cached(cache=caches[cls.__name__], lock=locks[cls.__name__])
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
AUTH_DB_CONNECTION_DURATION_SECONDS = Histogram(
|
||||
'auth_db_connection_duration_seconds',
|
||||
'Time taken to get DB connection and fetch service from database',
|
||||
)
|
||||
|
||||
|
||||
class SerialisedModel(ABC, CacheMixin):
|
||||
def cache(func):
|
||||
@cachetools.cached(
|
||||
cache=caches[func.__qualname__],
|
||||
lock=locks[func.__qualname__],
|
||||
)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class SerialisedModel(ABC):
|
||||
|
||||
"""
|
||||
A SerialisedModel takes a dictionary, typically created by
|
||||
@@ -44,7 +55,7 @@ class SerialisedModel(ABC, CacheMixin):
|
||||
return super().__dir__() + list(sorted(self.ALLOWED_PROPERTIES))
|
||||
|
||||
|
||||
class SerialisedModelCollection(ABC, CacheMixin):
|
||||
class SerialisedModelCollection(ABC):
|
||||
|
||||
"""
|
||||
A SerialisedModelCollection takes a list of dictionaries, typically
|
||||
@@ -68,6 +79,7 @@ class SerialisedModelCollection(ABC, CacheMixin):
|
||||
|
||||
|
||||
class SerialisedTemplate(SerialisedModel):
|
||||
|
||||
ALLOWED_PROPERTIES = {
|
||||
'archived',
|
||||
'content',
|
||||
@@ -80,6 +92,23 @@ class SerialisedTemplate(SerialisedModel):
|
||||
'version',
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@cache
|
||||
def from_template_id_and_service_id(cls, template_id, service_id):
|
||||
|
||||
from app.dao.templates_dao import dao_get_template_by_id_and_service_id
|
||||
from app.schemas import template_schema
|
||||
|
||||
fetched_template = dao_get_template_by_id_and_service_id(
|
||||
template_id=template_id,
|
||||
service_id=service_id
|
||||
)
|
||||
|
||||
template_dict = template_schema.dump(fetched_template).data
|
||||
|
||||
db.session.commit()
|
||||
return cls(template_dict)
|
||||
|
||||
|
||||
class SerialisedService(SerialisedModel):
|
||||
ALLOWED_PROPERTIES = {
|
||||
@@ -92,6 +121,14 @@ class SerialisedService(SerialisedModel):
|
||||
'restricted',
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@cache
|
||||
def from_id(cls, service_id):
|
||||
from app.schemas import service_schema
|
||||
with AUTH_DB_CONNECTION_DURATION_SECONDS.time():
|
||||
fetched = dao_fetch_service_by_id(service_id)
|
||||
return cls(service_schema.dump(fetched).data)
|
||||
|
||||
|
||||
class SerialisedAPIKey(SerialisedModel):
|
||||
ALLOWED_PROPERTIES = {
|
||||
@@ -104,3 +141,11 @@ class SerialisedAPIKey(SerialisedModel):
|
||||
|
||||
class SerialisedAPIKeyCollection(SerialisedModelCollection):
|
||||
model = SerialisedAPIKey
|
||||
|
||||
@classmethod
|
||||
@cache
|
||||
def from_service_id(cls, service_id):
|
||||
return cls([
|
||||
{k: getattr(key, k) for k in SerialisedAPIKey.ALLOWED_PROPERTIES}
|
||||
for key in get_model_api_keys(service_id)
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user