Set org_type of a service when adding/updating org org type

The organisation_type of a service should match the organisation_type of
the service's organisation (if there is one). This changes
dao_update_organisation and dao_add_service_to_organisation to set the
organisation_type of any services when adding / updating an org.
This commit is contained in:
Katie Smith
2019-07-04 14:48:16 +01:00
parent 86f14563d0
commit 1671221642
2 changed files with 72 additions and 15 deletions

View File

@@ -1,7 +1,7 @@
from sqlalchemy.sql.expression import func
from app import db
from app.dao.dao_utils import transactional
from app.dao.dao_utils import transactional, version_class
from app.models import (
Organisation,
Domain,
@@ -77,18 +77,31 @@ def dao_update_organisation(organisation_id, **kwargs):
for domain in domains
])
db.session.commit()
if 'organisation_type' in kwargs:
organisation = Organisation.query.get(organisation_id)
if organisation.services:
_update_org_type_for_organisation_services(organisation)
return num_updated
@version_class(Service)
def _update_org_type_for_organisation_services(organisation):
for service in organisation.services:
service.organisation_type = organisation.organisation_type
db.session.add(service)
@transactional
@version_class(Service)
def dao_add_service_to_organisation(service, organisation_id):
organisation = Organisation.query.filter_by(
id=organisation_id
).one()
organisation.services.append(service)
service.organisation_type = organisation.organisation_type
db.session.add(service)
def dao_get_invited_organisation_user(user_id):

View File

@@ -4,6 +4,7 @@ import uuid
import pytest
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from app import db
from app.dao.organisation_dao import (
dao_get_organisations,
dao_get_organisation_by_email_address,
@@ -16,7 +17,7 @@ from app.dao.organisation_dao import (
dao_get_users_for_organisation,
dao_add_user_to_organisation
)
from app.models import Organisation
from app.models import Organisation, Service
from tests.app.db import (
create_domain,
@@ -29,7 +30,6 @@ from tests.app.db import (
def test_get_organisations_gets_all_organisations_alphabetically_with_active_organisations_first(
notify_db,
notify_db_session
):
m_active_org = create_organisation(name='m_active_organisation')
@@ -48,7 +48,7 @@ def test_get_organisations_gets_all_organisations_alphabetically_with_active_org
assert organisations[4] == z_inactive_org
def test_get_organisation_by_id_gets_correct_organisation(notify_db, notify_db_session):
def test_get_organisation_by_id_gets_correct_organisation(notify_db_session):
organisation = create_organisation()
organisation_from_db = dao_get_organisation_by_id(organisation.id)
@@ -56,10 +56,7 @@ def test_get_organisation_by_id_gets_correct_organisation(notify_db, notify_db_s
assert organisation_from_db == organisation
def test_update_organisation(
notify_db,
notify_db_session,
):
def test_update_organisation(notify_db_session):
create_organisation()
organisation = Organisation.query.one()
@@ -82,6 +79,8 @@ def test_update_organisation(
for attribute, value in data.items():
assert getattr(organisation, attribute) != value
assert organisation.updated_at is None
dao_update_organisation(organisation.id, **data)
organisation = Organisation.query.one()
@@ -89,6 +88,8 @@ def test_update_organisation(
for attribute, value in data.items():
assert getattr(organisation, attribute) == value
assert organisation.updated_at
@pytest.mark.parametrize('domain_list, expected_domains', (
(['abc', 'def'], {'abc', 'def'}),
@@ -101,7 +102,6 @@ def test_update_organisation(
),
))
def test_update_organisation_domains_lowercases(
notify_db,
notify_db_session,
domain_list,
expected_domains,
@@ -119,17 +119,61 @@ def test_update_organisation_domains_lowercases(
assert {domain.domain for domain in organisation.domains} == expected_domains
def test_add_service_to_organisation(notify_db, notify_db_session, sample_service, sample_organisation):
def test_update_organisation_does_not_update_the_service_org_type_if_org_type_is_not_provided(
sample_service,
sample_organisation,
):
sample_service.organisation_type = 'local'
sample_organisation.organisation_type = 'central'
sample_organisation.services.append(sample_service)
db.session.commit()
assert sample_organisation.name == 'sample organisation'
dao_update_organisation(sample_organisation.id, name='updated org name')
assert sample_organisation.name == 'updated org name'
assert sample_service.organisation_type == 'local'
def test_update_organisation_updates_the_service_org_type_if_org_type_is_provided(
sample_service,
sample_organisation,
):
sample_service.organisation_type = 'local'
sample_organisation.organisation_type = 'local'
sample_organisation.services.append(sample_service)
db.session.commit()
dao_update_organisation(sample_organisation.id, organisation_type='central')
assert sample_organisation.organisation_type == 'central'
assert sample_service.organisation_type == 'central'
assert Service.get_history_model().query.filter_by(
id=sample_service.id,
version=2
).one().organisation_type == 'central'
def test_add_service_to_organisation(sample_service, sample_organisation):
sample_service.organisation_type = 'local'
sample_organisation.organisation_type = 'central'
assert sample_organisation.services == []
dao_add_service_to_organisation(sample_service, sample_organisation.id)
assert len(sample_organisation.services) == 1
assert sample_organisation.services[0].id == sample_service.id
assert sample_organisation.services[0].organisation_type == 'central'
assert Service.get_history_model().query.filter_by(
id=sample_service.id,
version=2
).one().organisation_type == 'central'
def test_add_service_to_multiple_organisation_raises_error(
notify_db, notify_db_session, sample_service, sample_organisation):
def test_add_service_to_multiple_organisation_raises_error(sample_service, sample_organisation):
another_org = create_organisation()
dao_add_service_to_organisation(sample_service, sample_organisation.id)
@@ -140,7 +184,7 @@ def test_add_service_to_multiple_organisation_raises_error(
assert sample_organisation.services[0] == sample_service
def test_get_organisation_services(notify_db, notify_db_session, sample_service, sample_organisation):
def test_get_organisation_services(sample_service, sample_organisation):
another_service = create_service(service_name='service 2')
another_org = create_organisation()
@@ -154,7 +198,7 @@ def test_get_organisation_services(notify_db, notify_db_session, sample_service,
assert not other_org_services
def test_get_organisation_by_service_id(notify_db, notify_db_session, sample_service, sample_organisation):
def test_get_organisation_by_service_id(sample_service, sample_organisation):
another_service = create_service(service_name='service 2')
another_org = create_organisation()