diff --git a/app/dao/organisation_dao.py b/app/dao/organisation_dao.py index 80708f266..6082ae376 100644 --- a/app/dao/organisation_dao.py +++ b/app/dao/organisation_dao.py @@ -1,6 +1,6 @@ from app import db from app.dao.dao_utils import transactional -from app.models import Organisation +from app.models import Organisation, Service def dao_get_organisations(): @@ -9,10 +9,20 @@ def dao_get_organisations(): ).all() +def dao_get_organisation_services(organisation_id): + return Organisation.query.filter_by( + id=organisation_id + ).one().services + + def dao_get_organisation_by_id(organisation_id): return Organisation.query.filter_by(id=organisation_id).one() +def dao_get_organisation_by_service_id(service_id): + return Organisation.query.join(Organisation.services).filter(Service.id == service_id).first() + + @transactional def dao_create_organisation(organisation): db.session.add(organisation) @@ -23,3 +33,12 @@ def dao_update_organisation(organisation_id, **kwargs): return Organisation.query.filter_by(id=organisation_id).update( kwargs ) + + +@transactional +def dao_add_service_to_organisation(service, organisation_id): + organisation = Organisation.query.filter_by( + id=organisation_id + ).one() + + organisation.services.append(service) diff --git a/tests/app/conftest.py b/tests/app/conftest.py index bf675ed2b..5da494062 100644 --- a/tests/app/conftest.py +++ b/tests/app/conftest.py @@ -17,6 +17,7 @@ from app.models import ( TemplateHistory, ApiKey, Job, + Organisation, Notification, NotificationHistory, InvitedUser, @@ -40,6 +41,7 @@ from app.models import ( ServiceEmailReplyTo ) from app.dao.users_dao import (create_user_code, create_secret_code) +from app.dao.organisation_dao import dao_create_organisation from app.dao.services_dao import (dao_create_service, dao_add_user_to_service) from app.dao.templates_dao import dao_create_template from app.dao.api_key_dao import save_model_api_key @@ -1044,6 +1046,13 @@ def sample_inbound_numbers(notify_db, notify_db_session, sample_service): return inbound_numbers +@pytest.fixture +def sample_organisation(notify_db, notify_db_session): + org = Organisation(name='sample organisation') + dao_create_organisation(org) + return org + + @pytest.fixture def restore_provider_details(notify_db, notify_db_session): """ diff --git a/tests/app/dao/test_organisation_dao.py b/tests/app/dao/test_organisation_dao.py index a72670cb8..610d71ed5 100644 --- a/tests/app/dao/test_organisation_dao.py +++ b/tests/app/dao/test_organisation_dao.py @@ -1,11 +1,17 @@ +import pytest +from sqlalchemy.exc import IntegrityError + from app.dao.organisation_dao import ( dao_get_organisations, dao_get_organisation_by_id, + dao_get_organisation_by_service_id, + dao_get_organisation_services, dao_update_organisation, + dao_add_service_to_organisation, ) from app.models import Organisation -from tests.app.db import create_organisation +from tests.app.db import create_organisation, create_service def test_get_organisations_gets_all_organisations_alphabetically_with_active_organisations_first( @@ -40,14 +46,60 @@ def test_update_organisation(notify_db, notify_db_session): updated_name = 'new name' create_organisation() - organisation = Organisation.query.all() + organisation = Organisation.query.one() - assert len(organisation) == 1 - assert organisation[0].name != updated_name + assert organisation.name != updated_name - dao_update_organisation(organisation[0].id, **{'name': updated_name}) + dao_update_organisation(organisation.id, **{'name': updated_name}) - organisation = Organisation.query.all() + organisation = Organisation.query.one() - assert len(organisation) == 1 - assert organisation[0].name == updated_name + assert organisation.name == updated_name + + +def test_add_service_to_organisation(notify_db, notify_db_session, sample_service, sample_organisation): + dao_add_service_to_organisation(sample_service, sample_organisation.id) + + assert len(sample_organisation.services) == 1 + + +def test_add_service_to_multiple_organisation_raises_error( + notify_db, notify_db_session, sample_service, sample_organisation): + another_org = create_organisation() + dao_add_service_to_organisation(sample_service, sample_organisation.id) + + with pytest.raises(IntegrityError): + dao_add_service_to_organisation(sample_service, another_org.id) + + assert len(sample_organisation.services) == 1 + assert sample_organisation.services[0] == sample_service + + +def test_get_organisation_services(notify_db, notify_db_session, sample_service, sample_organisation): + another_service = create_service(service_name='service 2') + another_org = create_organisation() + + dao_add_service_to_organisation(sample_service, sample_organisation.id) + dao_add_service_to_organisation(another_service, sample_organisation.id) + + org_services = dao_get_organisation_services(sample_organisation.id) + other_org_services = dao_get_organisation_services(another_org.id) + + assert len(org_services) == 2 + assert org_services[0].name == sample_service.name + assert org_services[1].name == another_service.name + assert not other_org_services + + +def test_get_organisation_by_service_id(notify_db, notify_db_session, sample_service, sample_organisation): + another_service = create_service(service_name='service 2') + another_org = create_organisation() + + dao_add_service_to_organisation(sample_service, sample_organisation.id) + dao_add_service_to_organisation(another_service, another_org.id) + + organisation_1 = dao_get_organisation_by_service_id(str(sample_service.id)) + organisation_2 = dao_get_organisation_by_service_id(str(another_service.id)) + + assert organisation_1 == sample_organisation + assert organisation_2 == another_org