diff --git a/app/main/views/register.py b/app/main/views/register.py index e9e03a76a..4638eaeab 100644 --- a/app/main/views/register.py +++ b/app/main/views/register.py @@ -15,7 +15,7 @@ from flask import ( ) from flask_login import current_user -from app import user_api_client +from app import redis_client, user_api_client from app.main import main from app.main.forms import ( RegisterUserForm, @@ -113,24 +113,126 @@ def registration_continue(): raise Exception("Unexpected routing in registration_continue") +def get_invite_data_from_redis(state): + + invite_data = json.loads(redis_client.raw_get(f"invitedata-{state}")) + user_email = redis_client.raw_get(f"user_email-{state}").decode("utf8") + user_uuid = redis_client.raw_get(f"user_uuid-{state}").decode("utf8") + invited_user_email_address = redis_client.raw_get( + f"invited_user_email_address-{state}" + ).decode("utf8") + return invite_data, user_email, user_uuid, invited_user_email_address + + +def put_invite_data_in_redis( + state, invite_data, user_email, user_uuid, invited_user_email_address +): + ttl = 60 * 15 # 15 minutes + + redis_client.raw_set(f"invitedata-{state}", json.dumps(invite_data), ex=ttl) + redis_client.raw_set(f"user_email-{state}", user_email, ex=ttl) + redis_client.raw_set(f"user_uuid-{state}", user_uuid, ex=ttl) + redis_client.raw_set( + f"invited_user_email_address-{state}", + invited_user_email_address, + ex=ttl, + ) + + +def check_invited_user_email_address_matches_expected( + user_email, invited_user_email_address +): + if user_email.lower() != invited_user_email_address.lower(): + debug_msg("invited user email did not match expected email, abort(403)") + flash("You cannot accept an invite for another person.") + abort(403) + + @main.route("/set-up-your-profile", methods=["GET", "POST"]) @hide_from_search_engines def set_up_your_profile(): + debug_msg(f"Enter set_up_your_profile with request.args {request.args}") + code = request.args.get("code") + state = request.args.get("state") + login_gov_error = request.args.get("error") + + if redis_client.raw_get(f"invitedata-{state}") is None: + access_token = sign_in._get_access_token(code, state) + debug_msg("Got the access token for login.gov") + user_email, user_uuid = sign_in._get_user_email_and_uuid(access_token) + debug_msg( + f"Got the user_email {user_email} and user_uuid {user_uuid} from login.gov" + ) + invite_data = state.encode("utf8") + invite_data = base64.b64decode(invite_data) + invite_data = json.loads(invite_data) + debug_msg(f"final state {invite_data}") + invited_user_id = invite_data["invited_user_id"] + invited_user_email_address = get_invited_user_email_address(invited_user_id) + debug_msg(f"email address from the invite_date is {invited_user_email_address}") + check_invited_user_email_address_matches_expected( + user_email, invited_user_email_address + ) + + invited_user_accept_invite(invited_user_id) + debug_msg( + f"accepted invite user {invited_user_email_address} to service {invite_data['service_id']}" + ) + # We need to avoid taking a second trip through the login.gov code because we cannot pull the + # access token twice. So once we retrieve these values, let's park them in redis for 15 minutes + put_invite_data_in_redis( + state, invite_data, user_email, user_uuid, invited_user_email_address + ) + form = SetupUserProfileForm() - if form.validate_on_submit(): - # start login.gov - code = request.args.get("code") - state = request.args.get("state") - login_gov_error = request.args.get("error") - if code and state: - return _handle_login_dot_gov_invite(code, state, form) - elif login_gov_error: - current_app.logger.error(f"login.gov error: {login_gov_error}") - raise Exception(f"Could not login with login.gov {login_gov_error}") - # end login.gov + if ( + form.validate_on_submit() + and redis_client.raw_get(f"invitedata-{state}") is not None + ): + invite_data, user_email, user_uuid, invited_user_email_address = ( + get_invite_data_from_redis(state) + ) + # create or update the user + user = user_api_client.get_user_by_uuid_or_email(user_uuid, user_email) + if user is None: + user = User.register( + name=form.name.data, + email_address=user_email, + mobile_number=form.mobile_number.data, + password=str(uuid.uuid4()), + auth_type="sms_auth", + ) + debug_msg(f"registered user {form.name.data} with email {user_email}") + else: + user.update(mobile_number=form.mobile_number.data, name=form.name.data) + debug_msg(f"updated user {form.name.data}") + + # activate the user + user = user_api_client.get_user_by_uuid_or_email(user_uuid, user_email) + activate_user(user["id"]) + debug_msg("activated user") + usr = User.from_id(user["id"]) + usr.add_to_service( + invite_data["service_id"], + invite_data["permissions"], + invite_data["folder_permissions"], + invite_data["from_user_id"], + ) + debug_msg( + f"Added user {usr.email_address} to service {invite_data['service_id']}" + ) + return redirect(url_for("main.show_accounts_or_dashboard")) + + elif login_gov_error: + current_app.logger.error(f"login.gov error: {login_gov_error}") + abort(403) + + # we take two trips through this method, but should only hit this + # line on the first trip. On the second trip, we should get redirected + # to the accounts page because we have successfully registered. return render_template("views/set-up-your-profile.html", form=form) @@ -150,60 +252,3 @@ def invited_user_accept_invite(invited_user_id): def debug_msg(msg): current_app.logger.debug(hilite(msg)) - - -def _handle_login_dot_gov_invite(code, state, form): - debug_msg(f"enter _handle_login_dot_gov_invite with code {code} state {state}") - access_token = sign_in._get_access_token(code, state) - debug_msg("Got the access token for login.gov") - user_email, user_uuid = sign_in._get_user_email_and_uuid(access_token) - debug_msg( - f"Got the user_email {user_email} and user_uuid {user_uuid} from login.gov" - ) - debug_msg(f"raw state {state}") - invite_data = state.encode("utf8") - debug_msg(f"utf8 encoded state {invite_data}") - invite_data = base64.b64decode(invite_data) - debug_msg(f"b64 decoded state {invite_data}") - invite_data = json.loads(invite_data) - debug_msg(f"final state {invite_data}") - invited_user_id = invite_data["invited_user_id"] - invited_user_email_address = get_invited_user_email_address(invited_user_id) - debug_msg(f"email address from the invite_date is {invited_user_email_address}") - if user_email.lower() != invited_user_email_address.lower(): - debug_msg("invited user email did not match expected email, abort(403)") - flash("You cannot accept an invite for another person.") - session.pop("invited_user_id", None) - abort(403) - else: - invited_user_accept_invite(invited_user_id) - debug_msg( - f"invited user {invited_user_email_address} to service {invite_data['service_id']}" - ) - debug_msg("accepted invite") - user = user_api_client.get_user_by_uuid_or_email(user_uuid, user_email) - if user is None: - user = User.register( - name=form.name.data, - email_address=user_email, - mobile_number=form.mobile_number.data, - password=str(uuid.uuid4()), - auth_type="sms_auth", - ) - debug_msg(f"registered user {form.name.data} with email {user_email}") - - # activate the user - user = user_api_client.get_user_by_uuid_or_email(user_uuid, user_email) - activate_user(user["id"]) - debug_msg("activated user") - usr = User.from_id(user["id"]) - usr.add_to_service( - invite_data["service_id"], - invite_data["permissions"], - invite_data["folder_permissions"], - invite_data["from_user_id"], - ) - debug_msg( - f"Added user {usr.email_address} to service {invite_data['service_id']}" - ) - return redirect(url_for("main.show_accounts_or_dashboard")) diff --git a/tests/app/main/views/test_register.py b/tests/app/main/views/test_register.py index cf4fba5a2..688ab1623 100644 --- a/tests/app/main/views/test_register.py +++ b/tests/app/main/views/test_register.py @@ -5,8 +5,7 @@ from unittest.mock import ANY import pytest from flask import url_for -from app.main.forms import RegisterUserForm -from app.main.views.register import _handle_login_dot_gov_invite +from app.main.views.register import check_invited_user_email_address_matches_expected from app.models.user import User from tests.conftest import normalize_spaces @@ -382,85 +381,26 @@ def test_cannot_register_with_sms_auth_and_missing_mobile_number( assert err.attrs["data-error-label"] == "mobile_number" -def test_handle_login_dot_gov_invite_bad_email(client_request, mocker): - - mocker.patch( - "app.main.views.register.sign_in._get_access_token", - return_value="access token", - ) - - mocker.patch( - "app.main.views.register.sign_in._get_user_email_and_uuid", - return_value=["fake@fake.gov", "12345"], - ) - - mocker.patch( - "app.main.views.register.get_invited_user_email_address", - return_value="boo@fake.gov", - ) - +def test_check_invited_user_email_address_matches_expected(mocker): mock_flash = mocker.patch("app.main.views.register.flash") - mock_abort = mocker.patch("app.main.views.register.abort") - mocker.patch("app.main.views.register.invited_user_accept_invite") + check_invited_user_email_address_matches_expected("fake@fake.gov", "Fake@Fake.GOV") + mock_flash.assert_not_called() + mock_abort.assert_not_called() - invite_data = {"service_id": "service", "invited_user_id": "invited_user"} - invite_data = json.dumps(invite_data) - invite_data = invite_data.encode("utf8") - invite_data = base64.b64encode(invite_data) - invite_data = invite_data.decode("utf8") - _handle_login_dot_gov_invite("code", invite_data, RegisterUserForm()) + +def test_check_invited_user_email_address_doesnt_match_expected(mocker): + mock_flash = mocker.patch("app.main.views.register.flash") + mock_abort = mocker.patch("app.main.views.register.abort") + + check_invited_user_email_address_matches_expected("real@fake.gov", "Fake@Fake.GOV") mock_flash.assert_called_once_with( "You cannot accept an invite for another person." ) mock_abort.assert_called_once_with(403) -def test_handle_login_dot_gov_invite_good_email(client_request, mocker): - - mocker.patch( - "app.main.views.register.sign_in._get_access_token", - return_value="access token", - ) - - mocker.patch( - "app.main.views.register.sign_in._get_user_email_and_uuid", - return_value=["fake@fake.gov", "12345"], - ) - - mocker.patch( - "app.main.views.register.get_invited_user_email_address", - return_value="fake@fake.gov", - ) - - mocker.patch( - "app.main.views.register.user_api_client.get_user_by_uuid_or_email", - return_value={"id": "abc"}, - ) - - mock_user = mocker.patch( - "app.main.views.register.User.add_to_service", - ) - - mock_accept = mocker.patch("app.main.views.register.invited_user_accept_invite") - - invite_data = { - "service_id": "service", - "invited_user_id": "invited_user", - "permissions": ["manage_everything"], - "folder_permissions": [], - "from_user_id": "xyz", - } - invite_data = json.dumps(invite_data) - invite_data = invite_data.encode("utf8") - invite_data = base64.b64encode(invite_data) - invite_data = invite_data.decode("utf8") - _handle_login_dot_gov_invite("code", invite_data, RegisterUserForm()) - mock_accept.assert_called_once() - mock_user.assert_called_once_with("service", ["manage_everything"], [], "xyz") - - def decode_invite_data(state): state = state.encode("utf8") state = base64.b64decode(state)