diff --git a/app/dao/organisation_dao.py b/app/dao/organisation_dao.py index bdd430c8c..26b6dca1f 100644 --- a/app/dao/organisation_dao.py +++ b/app/dao/organisation_dao.py @@ -59,12 +59,12 @@ def dao_update_organisation(organisation_id, **kwargs): kwargs ) - if domains: + if isinstance(domains, list): Domain.query.filter_by(organisation_id=organisation_id).delete() db.session.bulk_save_objects([ - Domain(domain=domain, organisation_id=organisation_id) + Domain(domain=domain.lower(), organisation_id=organisation_id) for domain in domains ]) diff --git a/tests/app/dao/test_organisation_dao.py b/tests/app/dao/test_organisation_dao.py index 6375cf7b6..667fc1438 100644 --- a/tests/app/dao/test_organisation_dao.py +++ b/tests/app/dao/test_organisation_dao.py @@ -90,6 +90,35 @@ def test_update_organisation( assert getattr(organisation, attribute) == value +@pytest.mark.parametrize('domain_list, expected_domains', ( + (['abc', 'def'], {'abc', 'def'}), + (['ABC', 'DEF'], {'abc', 'def'}), + ([], set()), + (None, {'123', '456'}), + pytest.param( + ['abc', 'ABC'], {'abc'}, + marks=pytest.mark.xfail(raises=IntegrityError) + ), +)) +def test_update_organisation_domains_lowercases( + notify_db, + notify_db_session, + domain_list, + expected_domains, +): + create_organisation() + + organisation = Organisation.query.one() + + # Seed some domains + dao_update_organisation(organisation.id, domains=['123', '456']) + + # This should overwrite the seeded domains + dao_update_organisation(organisation.id, domains=domain_list) + + assert {domain.domain for domain in organisation.domains} == expected_domains + + def test_add_service_to_organisation(notify_db, notify_db_session, sample_service, sample_organisation): assert sample_organisation.services == []