diff --git a/app/authentication/auth.py b/app/authentication/auth.py index 1b7a286ac..e26edee48 100644 --- a/app/authentication/auth.py +++ b/app/authentication/auth.py @@ -8,7 +8,7 @@ from sqlalchemy.exc import DataError from sqlalchemy.orm.exc import NoResultFound from gds_metrics import Histogram -from app.dao.services_dao import dao_fetch_service_by_id_with_api_keys +from app.serialised_models import SerialisedService GENERAL_TOKEN_ERROR_MESSAGE = 'Invalid token: make sure your API token matches the example at https://docs.notifications.service.gov.uk/rest-api.html#authorisation-header' # noqa @@ -94,7 +94,7 @@ def requires_auth(): try: with AUTH_DB_CONNECTION_DURATION_SECONDS.time(): - service = dao_fetch_service_by_id_with_api_keys(issuer) + service = SerialisedService.from_id(issuer) except DataError: raise AuthError("Invalid token: service id is not the right data type", 403) except NoResultFound: @@ -129,7 +129,7 @@ def requires_auth(): if api_key.expiry_date: raise AuthError("Invalid token: API key revoked", 403, service_id=service.id, api_key_id=api_key.id) - g.service_id = api_key.service_id + g.service_id = service.id _request_ctx_stack.top.authenticated_service = service _request_ctx_stack.top.api_user = api_key diff --git a/app/config.py b/app/config.py index 7a9f881d3..257eaea61 100644 --- a/app/config.py +++ b/app/config.py @@ -400,7 +400,12 @@ class Test(Development): NOTIFY_ENVIRONMENT = 'test' TESTING = True - HIGH_VOLUME_SERVICE = ['941b6f9a-50d7-4742-8d50-f365ca74bf27'] + HIGH_VOLUME_SERVICE = [ + '941b6f9a-50d7-4742-8d50-f365ca74bf27', + '63f95b86-2d19-4497-b8b2-ccf25457df4e', + '7e5950cb-9954-41f5-8376-962b8c8555cf', + '10d1b9c9-0072-4fa9-ae1c-595e333841da', + ] CSV_UPLOAD_BUCKET_NAME = 'test-notifications-csv-upload' CONTACT_LIST_BUCKET_NAME = 'test-contact-list' diff --git a/app/notifications/process_notifications.py b/app/notifications/process_notifications.py index 3f3f86fd6..e5ca0adcc 100644 --- a/app/notifications/process_notifications.py +++ b/app/notifications/process_notifications.py @@ -120,7 +120,6 @@ def persist_notification( template_version=template_version, to=recipient, service_id=service.id, - service=service, personalisation=personalisation, notification_type=notification_type, api_key_id=api_key_id, diff --git a/app/notifications/validators.py b/app/notifications/validators.py index c505d5a19..d1bf7b589 100644 --- a/app/notifications/validators.py +++ b/app/notifications/validators.py @@ -90,7 +90,7 @@ def service_can_send_to_recipient(send_to, key_type, service, allow_whitelisted_ def service_has_permission(notify_type, permissions): - return notify_type in [p.permission for p in permissions] + return notify_type in permissions def check_service_has_permission(notify_type, permissions): @@ -131,7 +131,7 @@ def check_if_service_can_send_to_number(service, number): if ( # if number is international and not a crown dependency international_phone_info.international and not international_phone_info.crown_dependency - ) and INTERNATIONAL_SMS_TYPE not in [p.permission for p in service.permissions]: + ) and INTERNATIONAL_SMS_TYPE not in service.permissions: raise BadRequestError(message="Cannot send to international mobile numbers") else: return international_phone_info diff --git a/app/schemas.py b/app/schemas.py index 36d9ff815..763f34143 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -205,7 +205,7 @@ class ProviderDetailsHistorySchema(BaseSchema): strict = True -class ServiceSchema(BaseSchema): +class ServiceSchema(BaseSchema, UUIDsAsStringsMixin): created_by = field_for(models.Service, 'created_by', required=True) organisation_type = field_for(models.Service, 'organisation_type') diff --git a/app/serialised_models.py b/app/serialised_models.py index d544dc478..7f88171d2 100644 --- a/app/serialised_models.py +++ b/app/serialised_models.py @@ -5,9 +5,13 @@ from threading import RLock import cachetools from notifications_utils.clients.redis import RequestCache +from werkzeug.utils import cached_property + +from app import db, redis_store -from app import redis_store from app.dao import templates_dao +from app.dao.api_key_dao import get_model_api_keys +from app.dao.services_dao import dao_fetch_service_by_id caches = defaultdict(partial(cachetools.TTLCache, maxsize=1024, ttl=2)) locks = defaultdict(RLock) @@ -53,6 +57,29 @@ class SerialisedModel(ABC): return super().__dir__() + list(sorted(self.ALLOWED_PROPERTIES)) +class SerialisedModelCollection(ABC): + + """ + A SerialisedModelCollection takes a list of dictionaries, typically + created by serialising database objects. When iterated over it + returns a SerialisedModel instance for each of the items in the list. + """ + + @property + @abstractmethod + def model(self): + pass + + def __init__(self, items): + self.items = items + + def __bool__(self): + return bool(self.items) + + def __getitem__(self, index): + return self.model(self.items[index]) + + class SerialisedTemplate(SerialisedModel): ALLOWED_PROPERTIES = { 'archived', @@ -82,5 +109,60 @@ class SerialisedTemplate(SerialisedModel): ) template_dict = template_schema.dump(fetched_template).data + db.session.commit() return {'data': template_dict} + + +class SerialisedService(SerialisedModel): + ALLOWED_PROPERTIES = { + 'id', + 'active', + 'contact_link', + 'email_from', + 'permissions', + 'research_mode', + 'restricted', + } + + @classmethod + @memory_cache + def from_id(cls, service_id): + return cls(cls.get_dict(service_id)['data']) + + @staticmethod + @redis_cache.set('service-{service_id}') + def get_dict(service_id): + from app.schemas import service_schema + + service_dict = service_schema.dump(dao_fetch_service_by_id(service_id)).data + db.session.commit() + + return {'data': service_dict} + + @cached_property + def api_keys(self): + return SerialisedAPIKeyCollection.from_service_id(self.id) + + +class SerialisedAPIKey(SerialisedModel): + ALLOWED_PROPERTIES = { + 'id', + 'secret', + 'expiry_date', + 'key_type', + } + + +class SerialisedAPIKeyCollection(SerialisedModelCollection): + model = SerialisedAPIKey + + @classmethod + @memory_cache + def from_service_id(cls, service_id): + keys = [ + {k: getattr(key, k) for k in SerialisedAPIKey.ALLOWED_PROPERTIES} + for key in get_model_api_keys(service_id) + ] + db.session.commit() + return cls(keys) diff --git a/app/service/send_notification.py b/app/service/send_notification.py index a7886b624..10ed42ab7 100644 --- a/app/service/send_notification.py +++ b/app/service/send_notification.py @@ -137,7 +137,9 @@ def get_reply_to_text(notification_type, sender_id, service, template): def send_pdf_letter_notification(service_id, post_data): service = dao_fetch_service_by_id(service_id) - check_service_has_permission(LETTER_TYPE, service.permissions) + check_service_has_permission(LETTER_TYPE, [ + p.permission for p in service.permissions + ]) check_service_over_daily_message_limit(KEY_TYPE_NORMAL, service) validate_created_by(service, post_data['created_by']) validate_and_format_recipient( diff --git a/app/service/utils.py b/app/service/utils.py index 16521fbbd..c6e0e0476 100644 --- a/app/service/utils.py +++ b/app/service/utils.py @@ -7,6 +7,8 @@ from app.models import ( MOBILE_TYPE, EMAIL_TYPE, KEY_TYPE_TEST, KEY_TYPE_TEAM, KEY_TYPE_NORMAL) +from app.dao.services_dao import dao_fetch_service_by_id + def get_recipients_from_request(request_json, key, type): return [(type, recipient) for recipient in request_json.get(key)] @@ -33,6 +35,10 @@ def service_allowed_to_send_to(recipient, service, key_type, allow_whitelisted_r if key_type == KEY_TYPE_NORMAL and not service.restricted: return True + # Revert back to the ORM model here so we can get some things which + # aren’t in the serialised model + service = dao_fetch_service_by_id(service.id) + team_members = itertools.chain.from_iterable( [user.mobile_number, user.email_address] for user in service.users ) diff --git a/app/template/rest.py b/app/template/rest.py index da401094b..464fad1c3 100644 --- a/app/template/rest.py +++ b/app/template/rest.py @@ -69,7 +69,9 @@ def validate_parent_folder(template_json): def create_template(service_id): fetched_service = dao_fetch_service_by_id(service_id=service_id) # permissions needs to be placed here otherwise marshmallow will interfere with versioning - permissions = fetched_service.permissions + permissions = [ + p.permission for p in fetched_service.permissions + ] template_json = validate(request.get_json(), post_create_template_schema) folder = validate_parent_folder(template_json=template_json) new_template = Template.from_json(template_json, folder) @@ -102,7 +104,12 @@ def create_template(service_id): def update_template(service_id, template_id): fetched_template = dao_get_template_by_id_and_service_id(template_id=template_id, service_id=service_id) - if not service_has_permission(fetched_template.template_type, fetched_template.service.permissions): + if not service_has_permission( + fetched_template.template_type, + [ + p.permission for p in fetched_template.service.permissions + ] + ): message = "Updating {} templates is not allowed".format( get_public_notify_type_text(fetched_template.template_type)) errors = {'template_type': [message]} diff --git a/app/v2/notifications/post_notifications.py b/app/v2/notifications/post_notifications.py index cc77be913..af0d766ea 100644 --- a/app/v2/notifications/post_notifications.py +++ b/app/v2/notifications/post_notifications.py @@ -332,7 +332,7 @@ def process_letter_notification( if api_key.key_type == KEY_TYPE_TEAM: raise BadRequestError(message='Cannot send letters with a team api key', status_code=403) - if not api_key.service.research_mode and api_key.service.restricted and api_key.key_type != KEY_TYPE_TEST: + if not service.research_mode and service.restricted and api_key.key_type != KEY_TYPE_TEST: raise BadRequestError(message='Cannot send letters when service is in trial mode', status_code=403) if precompiled: @@ -342,7 +342,7 @@ def process_letter_notification( template=template, reply_to_text=reply_to_text) - validate_address(api_key, letter_data) + validate_address(service, letter_data) test_key = api_key.key_type == KEY_TYPE_TEST @@ -391,10 +391,10 @@ def process_letter_notification( return resp -def validate_address(api_key, letter_data): +def validate_address(service, letter_data): address = PostalAddress.from_personalisation( letter_data['personalisation'], - allow_international_letters=api_key.service.has_permission(INTERNATIONAL_LETTERS), + allow_international_letters=(INTERNATIONAL_LETTERS in service.permissions), ) if not address.has_enough_lines: raise ValidationError( diff --git a/requirements-app.txt b/requirements-app.txt index 57b04538e..444164308 100644 --- a/requirements-app.txt +++ b/requirements-app.txt @@ -20,6 +20,7 @@ marshmallow==2.21.0 # pyup: <3 # v3 throws errors psycopg2-binary==2.8.5 PyJWT==1.7.1 SQLAlchemy==1.3.17 +cachetools==4.1.0 notifications-python-client==5.5.1 diff --git a/requirements.txt b/requirements.txt index c12e5dc4b..c70959d8c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,6 +22,7 @@ marshmallow==2.21.0 # pyup: <3 # v3 throws errors psycopg2-binary==2.8.5 PyJWT==1.7.1 SQLAlchemy==1.3.17 +cachetools==4.1.0 notifications-python-client==5.5.1 @@ -39,15 +40,14 @@ alembic==1.4.2 amqp==1.4.9 anyjson==0.3.3 attrs==19.3.0 -awscli==1.18.84 +awscli==1.18.85 bcrypt==3.1.7 billiard==3.3.0.23 bleach==3.1.4 blinker==1.4 boto==2.49.0 boto3==1.10.38 -botocore==1.17.7 -cachetools==4.1.0 +botocore==1.17.8 certifi==2020.6.20 chardet==3.0.4 click==7.1.2 diff --git a/tests/app/authentication/test_authentication.py b/tests/app/authentication/test_authentication.py index 3b6b086aa..24b212f4c 100644 --- a/tests/app/authentication/test_authentication.py +++ b/tests/app/authentication/test_authentication.py @@ -8,9 +8,18 @@ import pytest from flask import json, current_app, request from freezegun import freeze_time from notifications_python_client.authentication import create_jwt_token +from unittest.mock import call from app import api_user -from app.dao.api_key_dao import get_unsigned_secrets, save_model_api_key, get_unsigned_secret, expire_api_key +from app.dao.api_key_dao import ( + get_unsigned_secrets, + save_model_api_key, + get_unsigned_secret, + expire_api_key, + get_model_api_keys, +) +from app.dao.services_dao import dao_fetch_service_by_id + from app.models import ApiKey, KEY_TYPE_NORMAL from app.authentication.auth import AuthError, requires_admin_auth, requires_auth, GENERAL_TOKEN_ERROR_MESSAGE @@ -300,7 +309,7 @@ def test_authentication_returns_token_expired_when_service_uses_expired_key_and_ with pytest.raises(AuthError) as exc: requires_auth() assert exc.value.short_message == 'Invalid token: API key revoked' - assert exc.value.service_id == expired_api_key.service_id + assert exc.value.service_id == str(expired_api_key.service_id) assert exc.value.api_key_id == expired_api_key.id @@ -376,7 +385,7 @@ def test_authentication_returns_error_when_service_has_no_secrets(client, with pytest.raises(AuthError) as exc: requires_auth() assert exc.value.short_message == 'Invalid token: service has no API keys' - assert exc.value.service_id == sample_service.id + assert exc.value.service_id == str(sample_service.id) def test_should_attach_the_current_api_key_to_current_app(notify_api, sample_service, sample_api_key): @@ -387,7 +396,7 @@ def test_should_attach_the_current_api_key_to_current_app(notify_api, sample_ser headers={'Authorization': 'Bearer {}'.format(token)} ) assert response.status_code == 200 - assert api_user == sample_api_key + assert str(api_user.id) == str(sample_api_key.id) def test_should_return_403_when_token_is_expired(client, @@ -399,8 +408,8 @@ def test_should_return_403_when_token_is_expired(client, request.headers = {'Authorization': 'Bearer {}'.format(token)} requires_auth() assert exc.value.short_message == 'Error: Your system clock must be accurate to within 30 seconds' - assert exc.value.service_id == sample_api_key.service_id - assert exc.value.api_key_id == sample_api_key.id + assert exc.value.service_id == str(sample_api_key.service_id) + assert str(exc.value.api_key_id) == str(sample_api_key.id) def __create_token(service_id): @@ -457,3 +466,28 @@ def test_proxy_key_on_admin_auth_endpoint(notify_api, check_proxy_header, header ] ) assert response.status_code == expected_status + + +def test_should_cache_service_and_api_key_lookups(mocker, client, sample_api_key): + + mock_get_api_keys = mocker.patch( + 'app.serialised_models.get_model_api_keys', + wraps=get_model_api_keys, + ) + mock_get_service = mocker.patch( + 'app.serialised_models.dao_fetch_service_by_id', + wraps=dao_fetch_service_by_id, + ) + + for i in range(5): + token = __create_token(sample_api_key.service_id) + client.get('/notifications', headers={ + 'Authorization': f'Bearer {token}' + }) + + assert mock_get_api_keys.call_args_list == [ + call(str(sample_api_key.service_id)) + ] + assert mock_get_service.call_args_list == [ + call(str(sample_api_key.service_id)) + ] diff --git a/tests/app/notifications/test_validators.py b/tests/app/notifications/test_validators.py index 1c4db4190..ce3b0ce7a 100644 --- a/tests/app/notifications/test_validators.py +++ b/tests/app/notifications/test_validators.py @@ -23,7 +23,7 @@ from app.notifications.validators import ( validate_and_format_recipient, validate_template, ) -from app.serialised_models import SerialisedTemplate +from app.serialised_models import SerialisedService, SerialisedTemplate from app.utils import get_template_instance from app.v2.errors import ( @@ -439,8 +439,9 @@ def test_rejects_api_calls_with_international_numbers_if_service_does_not_allow_ notify_db_session, ): service = create_service(service_permissions=[SMS_TYPE]) + service_model = SerialisedService.from_id(service.id) with pytest.raises(BadRequestError) as e: - validate_and_format_recipient('20-12-1234-1234', key_type, service, SMS_TYPE) + validate_and_format_recipient('20-12-1234-1234', key_type, service_model, SMS_TYPE) assert e.value.status_code == 400 assert e.value.message == 'Cannot send to international mobile numbers' assert e.value.fields == [] @@ -449,7 +450,8 @@ def test_rejects_api_calls_with_international_numbers_if_service_does_not_allow_ @pytest.mark.parametrize('key_type', ['test', 'normal']) def test_allows_api_calls_with_international_numbers_if_service_does_allow_int_sms( key_type, sample_service_full_permissions): - result = validate_and_format_recipient('20-12-1234-1234', key_type, sample_service_full_permissions, SMS_TYPE) + service_model = SerialisedService.from_id(sample_service_full_permissions.id) + result = validate_and_format_recipient('20-12-1234-1234', key_type, service_model, SMS_TYPE) assert result == '201212341234' diff --git a/tests/app/v2/notifications/test_post_notifications.py b/tests/app/v2/notifications/test_post_notifications.py index 536c4b51f..6d1ea71c2 100644 --- a/tests/app/v2/notifications/test_post_notifications.py +++ b/tests/app/v2/notifications/test_post_notifications.py @@ -232,14 +232,14 @@ def test_should_cache_template_lookups_in_memory(mocker, client, sample_template assert mock_get_template.call_count == 1 assert mock_get_template.call_args_list == [ - call(service_id=sample_template.service_id, template_id=str(sample_template.id)) + call(service_id=str(sample_template.service_id), template_id=str(sample_template.id)) ] assert Notification.query.count() == 5 -def test_should_cache_template_lookups_in_redis(mocker, client, sample_template): +def test_should_cache_template_and_service_in_redis(mocker, client, sample_template): - from app.schemas import template_schema + from app.schemas import service_schema, template_schema mock_redis_get = mocker.patch( 'app.redis_store.get', @@ -263,34 +263,49 @@ def test_should_cache_template_lookups_in_redis(mocker, client, sample_template) headers=[('Content-Type', 'application/json'), auth_header] ) - expected_key = f'template-{sample_template.id}-version-None' + expected_service_key = f'service-{sample_template.service_id}' + expected_templates_key = f'template-{sample_template.id}-version-None' - assert mock_redis_get.call_args_list == [call( - expected_key, - )] + assert mock_redis_get.call_args_list == [ + call(expected_service_key), + call(expected_templates_key), + ] + service_dict = service_schema.dump(sample_template.service).data template_dict = template_schema.dump(sample_template).data - assert len(mock_redis_set.call_args_list) == 1 - assert mock_redis_set.call_args[0][0] == expected_key - assert json.loads(mock_redis_set.call_args[0][1]) == { - 'data': template_dict, - } - assert mock_redis_set.call_args[1]['ex'] == 604_800 + assert len(mock_redis_set.call_args_list) == 2 + + service_call, templates_call = mock_redis_set.call_args_list + + assert service_call[0][0] == expected_service_key + assert json.loads(service_call[0][1]) == {'data': service_dict} + assert service_call[1]['ex'] == 604_800 + + assert templates_call[0][0] == expected_templates_key + assert json.loads(templates_call[0][1]) == {'data': template_dict} + assert templates_call[1]['ex'] == 604_800 def test_should_return_template_if_found_in_redis(mocker, client, sample_template): - from app.schemas import template_schema + from app.schemas import service_schema, template_schema + service_dict = service_schema.dump(sample_template.service).data template_dict = template_schema.dump(sample_template).data mocker.patch( 'app.redis_store.get', - return_value=json.dumps({'data': template_dict}).encode('utf-8') + side_effect=[ + json.dumps({'data': service_dict}).encode('utf-8'), + json.dumps({'data': template_dict}).encode('utf-8'), + ], ) mock_get_template = mocker.patch( 'app.dao.templates_dao.dao_get_template_by_id_and_service_id' ) + mock_get_service = mocker.patch( + 'app.dao.services_dao.dao_fetch_service_by_id' + ) mocker.patch('app.celery.provider_tasks.deliver_sms.apply_async') @@ -308,6 +323,7 @@ def test_should_return_template_if_found_in_redis(mocker, client, sample_templat assert response.status_code == 201 assert mock_get_template.called is False + assert mock_get_service.called is False @pytest.mark.parametrize("notification_type, key_send_to, send_to", @@ -867,8 +883,8 @@ def test_post_notification_with_document_upload(client, notify_db_session, mocke assert validate(resp_json, post_email_response) == resp_json assert document_download_mock.upload_document.call_args_list == [ - call(service.id, 'abababab', csv_param.get('is_csv')), - call(service.id, 'cdcdcdcd', csv_param.get('is_csv')) + call(str(service.id), 'abababab', csv_param.get('is_csv')), + call(str(service.id), 'cdcdcdcd', csv_param.get('is_csv')) ] notification = Notification.query.one() @@ -1017,7 +1033,10 @@ def test_post_notifications_saves_email_to_queue(client, notify_db_session, mock save_email_task = mocker.patch("app.celery.tasks.save_api_email.apply_async") mock_send_task = mocker.patch('app.celery.provider_tasks.deliver_email.apply_async') - service = create_service(service_id='941b6f9a-50d7-4742-8d50-f365ca74bf27', service_name='high volume service') + service = create_service( + service_id=current_app.config['HIGH_VOLUME_SERVICE'][0], + service_name='high volume service', + ) template = create_template(service=service, content='((message))', template_type=EMAIL_TYPE) data = { "email_address": "joe.citizen@example.com", @@ -1048,7 +1067,10 @@ def test_post_notifications_saves_email_normally_if_save_email_to_queue_fails(cl ) mock_send_task = mocker.patch('app.celery.provider_tasks.deliver_email.apply_async') - service = create_service(service_id='941b6f9a-50d7-4742-8d50-f365ca74bf27', service_name='high volume service') + service = create_service( + service_id=current_app.config['HIGH_VOLUME_SERVICE'][1], + service_name='high volume service', + ) template = create_template(service=service, content='((message))', template_type=EMAIL_TYPE) data = { "email_address": "joe.citizen@example.com", @@ -1077,8 +1099,10 @@ def test_post_notifications_doesnt_save_email_to_queue_for_test_emails(client, n save_email_task = mocker.patch("app.celery.tasks.save_api_email.apply_async") mock_send_task = mocker.patch('app.celery.provider_tasks.deliver_email.apply_async') - service = create_service(service_id='941b6f9a-50d7-4742-8d50-f365ca74bf27', service_name='high volume service') - # create_api_key(service=service, key_type='test') + service = create_service( + service_id=current_app.config['HIGH_VOLUME_SERVICE'][2], + service_name='high volume service', + ) template = create_template(service=service, content='((message))', template_type=EMAIL_TYPE) data = { "email_address": "joe.citizen@example.com", @@ -1107,7 +1131,10 @@ def test_post_notifications_doesnt_save_email_to_queue_for_sms(client, notify_db save_email_task = mocker.patch("app.celery.tasks.save_api_email.apply_async") mock_send_task = mocker.patch('app.celery.provider_tasks.deliver_sms.apply_async') - service = create_service(service_id='941b6f9a-50d7-4742-8d50-f365ca74bf27', service_name='high volume service') + service = create_service( + service_id=current_app.config['HIGH_VOLUME_SERVICE'][3], + service_name='high volume service', + ) template = create_template(service=service, content='((message))', template_type=SMS_TYPE) data = { "phone_number": '+447700900855',