diff --git a/app/dao/organisation_dao.py b/app/dao/organisation_dao.py index 9cb5a03e0..f9757664f 100644 --- a/app/dao/organisation_dao.py +++ b/app/dao/organisation_dao.py @@ -1,7 +1,7 @@ from sqlalchemy.sql.expression import func from app import db -from app.dao.dao_utils import transactional, version_class +from app.dao.dao_utils import VersionOptions, transactional, version_class from app.models import ( Organisation, Domain, @@ -77,18 +77,27 @@ def dao_update_organisation(organisation_id, **kwargs): for domain in domains ]) + organisation = Organisation.query.get(organisation_id) + if 'organisation_type' in kwargs: - organisation = Organisation.query.get(organisation_id) - if organisation.services: - _update_org_type_for_organisation_services(organisation) + _update_organisation_services(organisation, 'organisation_type', only_where_none=False) + + if 'email_branding_id' in kwargs: + _update_organisation_services(organisation, 'email_branding') + + if 'letter_branding_id' in kwargs: + _update_organisation_services(organisation, 'letter_branding') return num_updated -@version_class(Service) -def _update_org_type_for_organisation_services(organisation): +@version_class( + VersionOptions(Service, must_write_history=False), +) +def _update_organisation_services(organisation, attribute, only_where_none=True): for service in organisation.services: - service.organisation_type = organisation.organisation_type + if getattr(service, attribute) is None or not only_where_none: + setattr(service, attribute, getattr(organisation, attribute)) db.session.add(service) diff --git a/tests/app/dao/test_organisation_dao.py b/tests/app/dao/test_organisation_dao.py index aae0b3e61..cbe10267b 100644 --- a/tests/app/dao/test_organisation_dao.py +++ b/tests/app/dao/test_organisation_dao.py @@ -119,12 +119,17 @@ def test_update_organisation_domains_lowercases( assert {domain.domain for domain in organisation.domains} == expected_domains -def test_update_organisation_does_not_update_the_service_org_type_if_org_type_is_not_provided( +def test_update_organisation_does_not_update_the_service_if_certain_attributes_not_provided( sample_service, sample_organisation, ): + email_branding = create_email_branding() + letter_branding = create_letter_branding() + sample_service.organisation_type = 'local' sample_organisation.organisation_type = 'central' + sample_organisation.email_branding = email_branding + sample_organisation.letter_branding = letter_branding sample_organisation.services.append(sample_service) db.session.commit() @@ -134,8 +139,16 @@ def test_update_organisation_does_not_update_the_service_org_type_if_org_type_is dao_update_organisation(sample_organisation.id, name='updated org name') assert sample_organisation.name == 'updated org name' + + assert sample_organisation.organisation_type == 'central' assert sample_service.organisation_type == 'local' + assert sample_organisation.email_branding == email_branding + assert sample_service.email_branding is None + + assert sample_organisation.letter_branding == letter_branding + assert sample_service.letter_branding is None + def test_update_organisation_updates_the_service_org_type_if_org_type_is_provided( sample_service, @@ -157,6 +170,49 @@ def test_update_organisation_updates_the_service_org_type_if_org_type_is_provide ).one().organisation_type == 'central' +def test_update_organisation_updates_the_service_branding_if_branding_is_provided( + sample_service, + sample_organisation, +): + email_branding = create_email_branding() + letter_branding = create_letter_branding() + + sample_organisation.services.append(sample_service) + db.session.commit() + + dao_update_organisation(sample_organisation.id, email_branding_id=email_branding.id) + dao_update_organisation(sample_organisation.id, letter_branding_id=letter_branding.id) + + assert sample_organisation.email_branding == email_branding + assert sample_organisation.letter_branding == letter_branding + assert sample_service.email_branding == email_branding + assert sample_service.letter_branding == letter_branding + + +def test_update_organisation_does_not_override_service_branding( + sample_service, + sample_organisation, +): + email_branding = create_email_branding() + custom_email_branding = create_email_branding(name='custom') + letter_branding = create_letter_branding() + custom_letter_branding = create_letter_branding(name='custom', filename='custom') + + sample_service.email_branding = custom_email_branding + sample_service.letter_branding = custom_letter_branding + + sample_organisation.services.append(sample_service) + db.session.commit() + + dao_update_organisation(sample_organisation.id, email_branding_id=email_branding.id) + dao_update_organisation(sample_organisation.id, letter_branding_id=letter_branding.id) + + assert sample_organisation.email_branding == email_branding + assert sample_organisation.letter_branding == letter_branding + assert sample_service.email_branding == custom_email_branding + assert sample_service.letter_branding == custom_letter_branding + + def test_add_service_to_organisation(sample_service, sample_organisation): assert sample_organisation.services == []