diff --git a/app/dao/organisation_dao.py b/app/dao/organisation_dao.py index b7db3a4b0..f9757664f 100644 --- a/app/dao/organisation_dao.py +++ b/app/dao/organisation_dao.py @@ -80,43 +80,24 @@ def dao_update_organisation(organisation_id, **kwargs): organisation = Organisation.query.get(organisation_id) if 'organisation_type' in kwargs: - _update_org_type_for_organisation_services(organisation) + _update_organisation_services(organisation, 'organisation_type', only_where_none=False) if 'email_branding_id' in kwargs: - _update_email_branding_for_organisation_services(organisation) + _update_organisation_services(organisation, 'email_branding') if 'letter_branding_id' in kwargs: - _update_letter_branding_for_organisation_services(organisation) + _update_organisation_services(organisation, 'letter_branding') return num_updated @version_class( - VersionOptions(Service, must_write_history=False) + VersionOptions(Service, must_write_history=False), ) -def _update_org_type_for_organisation_services(organisation): +def _update_organisation_services(organisation, attribute, only_where_none=True): for service in organisation.services: - service.organisation_type = organisation.organisation_type - db.session.add(service) - - -@version_class( - VersionOptions(Service, must_write_history=False) -) -def _update_email_branding_for_organisation_services(organisation): - for service in organisation.services: - if service.email_branding is None: - service.email_branding = organisation.email_branding - db.session.add(service) - - -@version_class( - VersionOptions(Service, must_write_history=False) -) -def _update_letter_branding_for_organisation_services(organisation): - for service in organisation.services: - if service.letter_branding is None: - service.letter_branding = organisation.letter_branding + if getattr(service, attribute) is None or not only_where_none: + setattr(service, attribute, getattr(organisation, attribute)) db.session.add(service)