diff --git a/app/main/views/add_service.py b/app/main/views/add_service.py index 45314ad20..72ee367f4 100644 --- a/app/main/views/add_service.py +++ b/app/main/views/add_service.py @@ -5,6 +5,7 @@ from werkzeug.exceptions import abort from app import ( billing_api_client, + email_branding_client, invite_api_client, service_api_client, user_api_client, @@ -12,7 +13,7 @@ from app import ( from app.main import main from app.main.forms import CreateServiceForm from app.notify_client.models import InvitedUser -from app.utils import email_safe, is_gov_user +from app.utils import AgreementInfo, email_safe, is_gov_user def _add_invited_user_to_service(invited_user): @@ -27,6 +28,9 @@ def _add_invited_user_to_service(invited_user): def _create_service(service_name, organisation_type, email_from, form): free_sms_fragment_limit = current_app.config['DEFAULT_FREE_SMS_FRAGMENT_LIMITS'].get(organisation_type) + email_branding = email_branding_client.get_email_branding_id_for_domain( + AgreementInfo.from_current_user().canonical_domain + ) try: service_id = service_api_client.create_service( service_name=service_name, @@ -38,6 +42,9 @@ def _create_service(service_name, organisation_type, email_from, form): ) session['service_id'] = service_id + if email_branding: + service_api_client.update_service(service_id, email_branding=email_branding) + billing_api_client.create_or_update_free_sms_fragment_limit(service_id, free_sms_fragment_limit) return service_id, None diff --git a/app/notify_client/email_branding_client.py b/app/notify_client/email_branding_client.py index 21c98d0e7..4af1d6f0c 100644 --- a/app/notify_client/email_branding_client.py +++ b/app/notify_client/email_branding_client.py @@ -17,6 +17,12 @@ class EmailBrandingClient(NotifyAdminAPIClient): brandings.sort(key=lambda branding: branding[sort_key].lower()) return brandings + def get_email_branding_id_for_domain(self, domain): + for branding in self.get_all_email_branding(): + if domain and branding.get('domain') == domain: + return branding['id'] + return None + def get_letter_email_branding(self): return self.get(url='/dvla_organisations') diff --git a/tests/app/main/views/test_add_service.py b/tests/app/main/views/test_add_service.py index 0ed602f9c..ba0c606a3 100644 --- a/tests/app/main/views/test_add_service.py +++ b/tests/app/main/views/test_add_service.py @@ -33,6 +33,7 @@ def test_should_add_service_and_redirect_to_tour_when_no_services( mock_get_services_with_no_services, api_user_active, mock_create_or_update_free_sms_fragment_limit, + mock_get_all_email_branding, ): response = logged_in_client.post( url_for('main.add_service'), @@ -84,7 +85,8 @@ def test_should_add_service_and_redirect_to_dashboard_when_existing_service( api_user_active, organisation_type, free_allowance, - mock_create_or_update_free_sms_fragment_limit + mock_create_or_update_free_sms_fragment_limit, + mock_get_all_email_branding, ): response = logged_in_client.post( url_for('main.add_service'), @@ -109,6 +111,43 @@ def test_should_add_service_and_redirect_to_dashboard_when_existing_service( assert response.location == url_for('main.service_dashboard', service_id=101, _external=True) +@pytest.mark.parametrize('email_address, expected_branding', [ + ('test@example.voa.gsi.gov.uk', '5'), + ('test@example.voa.gov.uk', '5'), + ('test@example.gov.uk', None), +]) +def test_should_lookup_branding_for_known_domain( + app_, + client_request, + active_user_with_permissions, + mock_create_service, + mock_get_services, + mock_update_service, + mock_create_or_update_free_sms_fragment_limit, + mock_get_all_email_branding, + email_address, + expected_branding, +): + active_user_with_permissions.email_address = email_address + client_request.login(active_user_with_permissions) + client_request.post( + 'main.add_service', + _data={ + 'name': 'testing the post', + 'organisation_type': 'central', + } + ) + mock_get_all_email_branding.assert_called_once_with() + assert mock_create_service.called is True + if expected_branding: + mock_update_service.assert_called_once_with( + 101, + email_branding=expected_branding, + ) + else: + assert mock_update_service.called is False + + def test_should_return_form_errors_when_service_name_is_empty( logged_in_client ): @@ -120,6 +159,7 @@ def test_should_return_form_errors_when_service_name_is_empty( def test_should_return_form_errors_with_duplicate_service_name_regardless_of_case( logged_in_client, mock_create_duplicate_service, + mock_get_all_email_branding, ): response = logged_in_client.post( url_for('main.add_service'), diff --git a/tests/conftest.py b/tests/conftest.py index 55012896c..5781d9ebc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2449,7 +2449,7 @@ def create_email_brandings(number_of_brandings, non_standard_values={}, shuffle= } for idx in range(1, number_of_brandings + 1)] for idx, row in enumerate(non_standard_values): - brandings[row['idx']].update(non_standard_values) + brandings[row['idx']].update(non_standard_values[idx]) if shuffle: brandings.insert(3, brandings.pop(4)) @@ -2464,7 +2464,7 @@ def mock_get_all_email_branding(mocker): {'idx': 1, 'colour': 'red'}, {'idx': 2, 'colour': 'orange'}, {'idx': 3, 'text': None}, - {'idx': 4, 'colour': 'blue'}, + {'idx': 4, 'colour': 'blue', 'domain': 'voa.gov.uk'}, ] shuffle = sort_key is None return create_email_brandings(5, non_standard_values=non_standard_values, shuffle=shuffle)