diff --git a/app/main/views/add_service.py b/app/main/views/add_service.py index 4466114e2..1fed9c5ad 100644 --- a/app/main/views/add_service.py +++ b/app/main/views/add_service.py @@ -2,29 +2,12 @@ from flask import current_app, redirect, render_template, session, url_for from flask_login import login_required from notifications_python_client.errors import HTTPError -from app import ( - billing_api_client, - email_branding_client, - invite_api_client, - service_api_client, - user_api_client, -) +from app import billing_api_client, email_branding_client, service_api_client from app.main import main from app.main.forms import CreateServiceForm -from app.models.user import InvitedUser from app.utils import AgreementInfo, email_safe, user_is_gov_user -def _add_invited_user_to_service(invited_user): - invitation = InvitedUser(**invited_user) - # if invited user add to service and redirect to dashboard - user = user_api_client.get_user(session['user_id']) - service_id = invited_user['service'] - user_api_client.add_user_to_service(service_id, user.id, invitation.permissions) - invite_api_client.accept_invite(service_id, invitation.id) - return service_id - - def _create_service(service_name, organisation_type, email_from, form): free_sms_fragment_limit = current_app.config['DEFAULT_FREE_SMS_FRAGMENT_LIMITS'].get(organisation_type) email_branding = email_branding_client.get_email_branding_id_for_domain( @@ -70,11 +53,6 @@ def _create_example_template(service_id): @login_required @user_is_gov_user def add_service(): - invited_user = session.get('invited_user') - if invited_user: - service_id = _add_invited_user_to_service(invited_user) - return redirect(url_for('main.service_dashboard', service_id=service_id)) - form = CreateServiceForm() heading = 'About your service' diff --git a/app/main/views/register.py b/app/main/views/register.py index 4005e881c..209010409 100644 --- a/app/main/views/register.py +++ b/app/main/views/register.py @@ -68,7 +68,6 @@ def register_from_org_invite(): abort(400) _do_registration(form, send_email=False, send_sms=True, organisation_id=invited_org_user['organisation']) org_invite_api_client.accept_invite(invited_org_user['organisation'], invited_org_user['id']) - user_api_client.add_user_to_organisation(invited_org_user['organisation'], session['user_details']['id']) return redirect(url_for('main.verify')) return render_template('views/register-from-org-invite.html', invited_org_user=invited_org_user, form=form) diff --git a/app/main/views/verify.py b/app/main/views/verify.py index 14ac89470..f98dfaebe 100644 --- a/app/main/views/verify.py +++ b/app/main/views/verify.py @@ -16,6 +16,7 @@ from notifications_utils.url_safe_token import check_token from app import user_api_client from app.main import main from app.main.forms import TwoFactorForm +from app.models.user import InvitedUser from app.utils import redirect_to_sign_in @@ -70,10 +71,28 @@ def activate_user(user_id): user = user_api_client.get_user(user_id) # the user will have a new current_session_id set by the API - store it in the cookie for future requests session['current_session_id'] = user.current_session_id - organisation_id = session.get('organisation_id', None) + organisation_id = session.get('organisation_id') activated_user = user_api_client.activate_user(user) login_user(activated_user) + + invited_user = session.get('invited_user') + if invited_user: + service_id = _add_invited_user_to_service(invited_user) + return redirect(url_for('main.service_dashboard', service_id=service_id)) + + invited_org_user = session.get('invited_org_user') + if invited_org_user: + user_api_client.add_user_to_organisation(invited_org_user['organisation'], session['user_details']['id']) + if organisation_id: return redirect(url_for('main.organisation_dashboard', org_id=organisation_id)) else: return redirect(url_for('main.add_service', first='first')) + + +def _add_invited_user_to_service(invited_user): + invitation = InvitedUser(**invited_user) + user = user_api_client.get_user(session['user_id']) + service_id = invited_user['service'] + user_api_client.add_user_to_service(service_id, user.id, invitation.permissions) + return service_id diff --git a/tests/app/main/views/test_register.py b/tests/app/main/views/test_register.py index 39e59b5c4..279e53817 100644 --- a/tests/app/main/views/test_register.py +++ b/tests/app/main/views/test_register.py @@ -271,6 +271,7 @@ def test_register_from_invite_when_user_registers_in_another_browser( assert response.location == url_for('main.verify', _external=True) +@pytest.mark.parametrize('invite_email_address', ['gov-user@gov.uk', 'non-gov-user@example.com']) def test_register_from_email_auth_invite( client, sample_invite, @@ -281,8 +282,11 @@ def test_register_from_email_auth_invite( mock_send_verify_code, mock_accept_invite, mock_create_event, + mock_add_user_to_service, + invite_email_address, ): sample_invite['auth_type'] = 'email_auth' + sample_invite['email_address'] = invite_email_address with client.session_transaction() as session: session['invited_user'] = sample_invite assert not current_user.is_authenticated @@ -298,7 +302,7 @@ def test_register_from_email_auth_invite( resp = client.post(url_for('main.register_from_invite'), data=data) assert resp.status_code == 302 - assert resp.location == url_for('main.add_service', first='first', _external=True) + assert resp.location == url_for('main.service_dashboard', service_id=sample_invite['service'], _external=True) # doesn't send any 2fa code assert not mock_send_verify_email.called @@ -314,6 +318,7 @@ def test_register_from_email_auth_invite( mock_accept_invite.assert_called_once_with(sample_invite['service'], sample_invite['id']) # just logs them in assert current_user.is_authenticated + assert mock_add_user_to_service.called with client.session_transaction() as session: # invited user details are still there so they can get added to the service @@ -330,6 +335,7 @@ def test_can_register_email_auth_without_phone_number( mock_send_verify_code, mock_accept_invite, mock_create_event, + mock_add_user_to_service, ): sample_invite['auth_type'] = 'email_auth' with client.session_transaction() as session: @@ -346,7 +352,7 @@ def test_can_register_email_auth_without_phone_number( resp = client.post(url_for('main.register_from_invite'), data=data) assert resp.status_code == 302 - assert resp.location == url_for('main.add_service', first='first', _external=True) + assert resp.location == url_for('main.service_dashboard', service_id=sample_invite['service'], _external=True) mock_register_user.assert_called_once_with( ANY,