diff --git a/app/main/views/add_service.py b/app/main/views/add_service.py index 412b57a03..4fb52a377 100644 --- a/app/main/views/add_service.py +++ b/app/main/views/add_service.py @@ -48,7 +48,7 @@ def _create_example_template(service_id): @user_is_gov_user def add_service(): form = CreateServiceForm( - organisation_type=current_user.default_organisation.organisation_type + organisation_type=current_user.default_organisation_type ) heading = 'About your service' @@ -58,7 +58,7 @@ def add_service(): service_id, error = _create_service( service_name, - current_user.default_organisation.organisation_type or form.organisation_type.data, + current_user.default_organisation_type or form.organisation_type.data, email_from, form, ) diff --git a/app/models/user.py b/app/models/user.py index 54d786e2d..ce7f76f54 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -205,6 +205,20 @@ class User(UserMixin): organisations_client.get_organisation_by_domain(self.email_domain) ) + @property + def default_organisation_type(self): + if self.default_organisation: + return self.default_organisation.organisation_type + if self.has_nhs_email_address: + return 'nhs' + return None + + @property + def has_nhs_email_address(self): + return self.email_address.lower().endswith(( + '@nhs.uk', '.nhs.uk', '@nhs.net', '.nhs.net', + )) + def serialize(self): dct = { "id": self.id, diff --git a/app/templates/views/add-service.html b/app/templates/views/add-service.html index ad7c30229..b284e972e 100644 --- a/app/templates/views/add-service.html +++ b/app/templates/views/add-service.html @@ -20,7 +20,7 @@ {{ textbox(form.name, hint="You can change this later") }} - {% if not current_user.default_organisation.organisation_type %} + {% if not current_user.default_organisation_type %} {{ radios(form.organisation_type) }} {% endif %} diff --git a/tests/app/main/views/test_add_service.py b/tests/app/main/views/test_add_service.py index cbc8342bb..e6860cf2f 100644 --- a/tests/app/main/views/test_add_service.py +++ b/tests/app/main/views/test_add_service.py @@ -3,7 +3,7 @@ from flask import session, url_for from app.utils import is_gov_user from tests import organisation_json -from tests.conftest import mock_get_organisation_by_domain +from tests.conftest import mock_get_organisation_by_domain, normalize_spaces def test_non_gov_user_cannot_see_add_service_button( @@ -62,6 +62,11 @@ def test_get_should_not_render_radios_if_org_type_known( assert not page.select('.multiple-choice') +@pytest.mark.parametrize('email_address', ( + # User’s email address doesn’t matter when the organisation is known + 'test@example.gov.uk', + 'test@example.nhs.uk', +)) @pytest.mark.parametrize('inherited, posted, persisted, sms_limit', ( (None, 'central', 'central', 250000), ('central', None, 'central', 250000), @@ -79,10 +84,13 @@ def test_should_add_service_and_redirect_to_tour_when_no_services( mock_create_or_update_free_sms_fragment_limit, mock_get_all_email_branding, inherited, + email_address, posted, persisted, sms_limit, ): + api_user_active.email_address = email_address + client_request.login(api_user_active) mock_get_organisation_by_domain(mocker, organisation_type=inherited) client_request.post( 'main.add_service', @@ -120,6 +128,69 @@ def test_should_add_service_and_redirect_to_tour_when_no_services( mock_create_or_update_free_sms_fragment_limit.assert_called_once_with(101, sms_limit) +def test_add_service_has_to_choose_org_type( + mocker, + client_request, + mock_create_service, + mock_create_service_template, + mock_get_services_with_no_services, + api_user_active, + mock_create_or_update_free_sms_fragment_limit, + mock_get_all_email_branding, +): + mocker.patch( + 'app.organisations_client.get_organisation_by_domain', + return_value=None, + ) + page = client_request.post( + 'main.add_service', + _data={ + 'name': 'testing the post', + }, + _expected_status=200, + ) + assert normalize_spaces(page.select_one('.error-message').text) == ( + 'Not a valid choice' + ) + assert mock_create_service.called is False + assert mock_create_service_template.called is False + assert mock_create_or_update_free_sms_fragment_limit.called is False + + +@pytest.mark.parametrize('email_address', ( + 'test@nhs.net', + 'test@nhs.uk', + 'test@example.NhS.uK', + 'test@EXAMPLE.NHS.NET', + pytest.param( + 'test@not-nhs.uk', + marks=pytest.mark.xfail(raises=AssertionError) + ) +)) +def test_add_service_guesses_org_type_for_unknown_nhs_orgs( + mocker, + client_request, + mock_create_service, + mock_create_service_template, + mock_get_services_with_no_services, + api_user_active, + mock_create_or_update_free_sms_fragment_limit, + mock_get_all_email_branding, + email_address, +): + api_user_active.email_address = email_address + client_request.login(api_user_active) + mocker.patch( + 'app.organisations_client.get_organisation_by_domain', + return_value=None, + ) + client_request.post( + 'main.add_service', + _data={'name': 'example'}, + ) + assert mock_create_service.call_args[1]['organisation_type'] == 'nhs' + + @pytest.mark.parametrize('organisation_type, free_allowance', [ ('central', 250 * 1000), ('local', 25 * 1000),