diff --git a/app/dao/email_branding_dao.py b/app/dao/email_branding_dao.py index 87ccc61cf..d8738e920 100644 --- a/app/dao/email_branding_dao.py +++ b/app/dao/email_branding_dao.py @@ -12,7 +12,7 @@ def dao_get_email_branding_by_id(email_branding_id): def dao_get_email_branding_by_name(email_branding_name): - return EmailBranding.query.filter_by(name=email_branding_name).one() + return EmailBranding.query.filter_by(name=email_branding_name).first() @transactional diff --git a/app/dao/letter_branding_dao.py b/app/dao/letter_branding_dao.py index 20b5c018a..1011e392e 100644 --- a/app/dao/letter_branding_dao.py +++ b/app/dao/letter_branding_dao.py @@ -7,6 +7,10 @@ def dao_get_letter_branding_by_id(letter_branding_id): return LetterBranding.query.filter(LetterBranding.id == letter_branding_id).one() +def dao_get_letter_branding_by_name(letter_branding_name): + return LetterBranding.query.filter_by(name=letter_branding_name).first() + + def dao_get_letter_branding_by_domain(domain): if not domain: return None diff --git a/app/dao/services_dao.py b/app/dao/services_dao.py index b033e9cad..8994d92ff 100644 --- a/app/dao/services_dao.py +++ b/app/dao/services_dao.py @@ -11,6 +11,8 @@ from app.dao.dao_utils import ( transactional, version_class ) +from app.dao.email_branding_dao import dao_get_email_branding_by_name +from app.dao.letter_branding_dao import dao_get_letter_branding_by_name from app.dao.organisation_dao import dao_get_organisation_by_email_address from app.dao.service_sms_sender_dao import insert_service_sms_sender from app.dao.service_user_dao import dao_get_service_user @@ -38,7 +40,7 @@ from app.models import ( SMS_TYPE, LETTER_TYPE, ) -from app.utils import get_london_midnight_in_utc, midnight_n_days_ago +from app.utils import email_address_is_nhs, get_london_midnight_in_utc, midnight_n_days_ago DEFAULT_SERVICE_PERMISSIONS = [ SMS_TYPE, @@ -209,6 +211,12 @@ def dao_create_service( if organisation.letter_branding and not service.letter_branding: service.letter_branding = organisation.letter_branding + if not organisation and ( + service.organisation_type == 'nhs' or email_address_is_nhs(user.email_address) + ): + service.email_branding = dao_get_email_branding_by_name('NHS') + service.letter_branding = dao_get_letter_branding_by_name('NHS') + db.session.add(service) diff --git a/app/utils.py b/app/utils.py index fcd33b54b..3d6ee1804 100644 --- a/app/utils.py +++ b/app/utils.py @@ -122,3 +122,9 @@ def escape_special_characters(string): r'\{}'.format(special_character) ) return string + + +def email_address_is_nhs(email_address): + return email_address.lower().endswith(( + '@nhs.uk', '@nhs.net', '.nhs.uk', '.nhs.net', + )) diff --git a/tests/app/dao/test_services_dao.py b/tests/app/dao/test_services_dao.py index a8665554f..29da6d7eb 100644 --- a/tests/app/dao/test_services_dao.py +++ b/tests/app/dao/test_services_dao.py @@ -67,6 +67,7 @@ from tests.app.db import ( create_notification, create_api_key, create_invited_user, + create_email_branding, create_letter_branding, ) @@ -115,12 +116,56 @@ def test_create_service_with_letter_branding(notify_db_session): organisation_type='central', created_by=user) dao_create_service(service, user, letter_branding=letter_branding) - assert Service.query.count() == 1 service_db = Service.query.one() assert service_db.id == service.id assert service.letter_branding == letter_branding +@pytest.mark.parametrize('email_address, organisation_type', ( + ("test@example.gov.uk", 'nhs'), + ("test@nhs.net", 'nhs'), + ("test@nhs.net", 'local'), + ("test@nhs.net", 'central'), + ("test@nhs.uk", 'central'), + ("test@example.nhs.uk", 'central'), + ("TEST@NHS.UK", 'central'), +)) +@pytest.mark.parametrize('branding_name_to_create, expected_branding', ( + ('NHS', True), + # Need to check that nothing breaks in environments that don’t have + # the NHS branding set up + ('SHN', False), +)) +def test_create_nhs_service_get_default_branding_based_on_email_address( + notify_db_session, + branding_name_to_create, + expected_branding, + email_address, + organisation_type, +): + user = create_user(email=email_address) + letter_branding = create_letter_branding(name=branding_name_to_create) + email_branding = create_email_branding(name=branding_name_to_create) + + service = Service( + name="service_name", + email_from="email_from", + message_limit=1000, + restricted=False, + organisation_type=organisation_type, + created_by=user, + ) + dao_create_service(service, user, letter_branding=letter_branding) + service_db = Service.query.one() + + if expected_branding: + assert service_db.letter_branding == letter_branding + assert service_db.email_branding == email_branding + else: + assert service_db.letter_branding is None + assert service_db.email_branding is None + + def test_cannot_create_two_services_with_same_name(notify_db_session): user = create_user() assert Service.query.count() == 0