diff --git a/app/dao/organisation_dao.py b/app/dao/organisation_dao.py index 9cb5a03e0..b7db3a4b0 100644 --- a/app/dao/organisation_dao.py +++ b/app/dao/organisation_dao.py @@ -1,7 +1,7 @@ from sqlalchemy.sql.expression import func from app import db -from app.dao.dao_utils import transactional, version_class +from app.dao.dao_utils import VersionOptions, transactional, version_class from app.models import ( Organisation, Domain, @@ -77,21 +77,49 @@ def dao_update_organisation(organisation_id, **kwargs): for domain in domains ]) + organisation = Organisation.query.get(organisation_id) + if 'organisation_type' in kwargs: - organisation = Organisation.query.get(organisation_id) - if organisation.services: - _update_org_type_for_organisation_services(organisation) + _update_org_type_for_organisation_services(organisation) + + if 'email_branding_id' in kwargs: + _update_email_branding_for_organisation_services(organisation) + + if 'letter_branding_id' in kwargs: + _update_letter_branding_for_organisation_services(organisation) return num_updated -@version_class(Service) +@version_class( + VersionOptions(Service, must_write_history=False) +) def _update_org_type_for_organisation_services(organisation): for service in organisation.services: service.organisation_type = organisation.organisation_type db.session.add(service) +@version_class( + VersionOptions(Service, must_write_history=False) +) +def _update_email_branding_for_organisation_services(organisation): + for service in organisation.services: + if service.email_branding is None: + service.email_branding = organisation.email_branding + db.session.add(service) + + +@version_class( + VersionOptions(Service, must_write_history=False) +) +def _update_letter_branding_for_organisation_services(organisation): + for service in organisation.services: + if service.letter_branding is None: + service.letter_branding = organisation.letter_branding + db.session.add(service) + + @transactional @version_class(Service) def dao_add_service_to_organisation(service, organisation_id): diff --git a/tests/app/dao/test_organisation_dao.py b/tests/app/dao/test_organisation_dao.py index aae0b3e61..cbe10267b 100644 --- a/tests/app/dao/test_organisation_dao.py +++ b/tests/app/dao/test_organisation_dao.py @@ -119,12 +119,17 @@ def test_update_organisation_domains_lowercases( assert {domain.domain for domain in organisation.domains} == expected_domains -def test_update_organisation_does_not_update_the_service_org_type_if_org_type_is_not_provided( +def test_update_organisation_does_not_update_the_service_if_certain_attributes_not_provided( sample_service, sample_organisation, ): + email_branding = create_email_branding() + letter_branding = create_letter_branding() + sample_service.organisation_type = 'local' sample_organisation.organisation_type = 'central' + sample_organisation.email_branding = email_branding + sample_organisation.letter_branding = letter_branding sample_organisation.services.append(sample_service) db.session.commit() @@ -134,8 +139,16 @@ def test_update_organisation_does_not_update_the_service_org_type_if_org_type_is dao_update_organisation(sample_organisation.id, name='updated org name') assert sample_organisation.name == 'updated org name' + + assert sample_organisation.organisation_type == 'central' assert sample_service.organisation_type == 'local' + assert sample_organisation.email_branding == email_branding + assert sample_service.email_branding is None + + assert sample_organisation.letter_branding == letter_branding + assert sample_service.letter_branding is None + def test_update_organisation_updates_the_service_org_type_if_org_type_is_provided( sample_service, @@ -157,6 +170,49 @@ def test_update_organisation_updates_the_service_org_type_if_org_type_is_provide ).one().organisation_type == 'central' +def test_update_organisation_updates_the_service_branding_if_branding_is_provided( + sample_service, + sample_organisation, +): + email_branding = create_email_branding() + letter_branding = create_letter_branding() + + sample_organisation.services.append(sample_service) + db.session.commit() + + dao_update_organisation(sample_organisation.id, email_branding_id=email_branding.id) + dao_update_organisation(sample_organisation.id, letter_branding_id=letter_branding.id) + + assert sample_organisation.email_branding == email_branding + assert sample_organisation.letter_branding == letter_branding + assert sample_service.email_branding == email_branding + assert sample_service.letter_branding == letter_branding + + +def test_update_organisation_does_not_override_service_branding( + sample_service, + sample_organisation, +): + email_branding = create_email_branding() + custom_email_branding = create_email_branding(name='custom') + letter_branding = create_letter_branding() + custom_letter_branding = create_letter_branding(name='custom', filename='custom') + + sample_service.email_branding = custom_email_branding + sample_service.letter_branding = custom_letter_branding + + sample_organisation.services.append(sample_service) + db.session.commit() + + dao_update_organisation(sample_organisation.id, email_branding_id=email_branding.id) + dao_update_organisation(sample_organisation.id, letter_branding_id=letter_branding.id) + + assert sample_organisation.email_branding == email_branding + assert sample_organisation.letter_branding == letter_branding + assert sample_service.email_branding == custom_email_branding + assert sample_service.letter_branding == custom_letter_branding + + def test_add_service_to_organisation(sample_service, sample_organisation): assert sample_organisation.services == []