diff --git a/app/config.py b/app/config.py index 8ec4db186..77138ca16 100644 --- a/app/config.py +++ b/app/config.py @@ -38,7 +38,7 @@ class Config(object): NR_MONITOR_ON = settings and settings.monitor_mode COMMIT_HASH = getenv("COMMIT_HASH", "--------")[0:7] - GOVERNMENT_EMAIL_DOMAIN_NAMES = ["gov"] + GOVERNMENT_EMAIL_DOMAIN_NAMES = ["gov", "mil", "si.edu"] # Logging NOTIFY_LOG_LEVEL = getenv("NOTIFY_LOG_LEVEL", "INFO") diff --git a/app/main/views/register.py b/app/main/views/register.py index 2829d37bb..7f50c6a19 100644 --- a/app/main/views/register.py +++ b/app/main/views/register.py @@ -26,6 +26,7 @@ from app.main.views import sign_in from app.main.views.verify import activate_user from app.models.user import InvitedOrgUser, InvitedUser, User from app.utils import hide_from_search_engines, hilite +from app.utils.user import is_gov_user @main.route("/register", methods=["GET", "POST"]) @@ -147,6 +148,11 @@ def check_invited_user_email_address_matches_expected( flash("You cannot accept an invite for another person.") abort(403) + if not is_gov_user(user_email): + debug_msg("invited user has a non-government email address.") + flash("You must use a government email address.") + abort(403) + @main.route("/set-up-your-profile", methods=["GET", "POST"]) @hide_from_search_engines diff --git a/app/main/views/sign_in.py b/app/main/views/sign_in.py index d00d36892..9f089fe42 100644 --- a/app/main/views/sign_in.py +++ b/app/main/views/sign_in.py @@ -4,7 +4,16 @@ import uuid import jwt import requests -from flask import Response, current_app, redirect, render_template, request, url_for +from flask import ( + Response, + abort, + current_app, + flash, + redirect, + render_template, + request, + url_for, +) from flask_login import current_user from app import login_manager, user_api_client @@ -15,6 +24,7 @@ from app.models.user import User from app.utils import hide_from_search_engines from app.utils.login import is_safe_redirect_url from app.utils.time import is_less_than_days_ago +from app.utils.user import is_gov_user from notifications_utils.url_safe_token import generate_token @@ -88,6 +98,12 @@ def _do_login_dot_gov(): try: access_token = _get_access_token(code, state) user_email, user_uuid = _get_user_email_and_uuid(access_token) + if not is_gov_user(user_email): + current_app.logger.error( + "invited user has a non-government email address." + ) + flash("You must use a government email address.") + abort(403) redirect_url = request.args.get("next") user = user_api_client.get_user_by_uuid_or_email(user_uuid, user_email) diff --git a/app/utils/user.py b/app/utils/user.py index 668fcb646..a40de8558 100644 --- a/app/utils/user.py +++ b/app/utils/user.py @@ -4,7 +4,6 @@ from flask import abort, current_app from flask_login import current_user, login_required from app import config -from app.notify_client.organizations_api_client import organizations_client user_is_logged_in = login_required @@ -51,7 +50,7 @@ def user_is_platform_admin(f): def is_gov_user(email_address): return _email_address_ends_with( email_address, config.Config.GOVERNMENT_EMAIL_DOMAIN_NAMES - ) or _email_address_ends_with(email_address, organizations_client.get_domains()) + ) # or _email_address_ends_with(email_address, organizations_client.get_domains()) def _email_address_ends_with(email_address, known_domains): diff --git a/tests/app/main/views/test_register.py b/tests/app/main/views/test_register.py index 688ab1623..19d8c5a4b 100644 --- a/tests/app/main/views/test_register.py +++ b/tests/app/main/views/test_register.py @@ -148,7 +148,7 @@ def test_should_return_200_when_email_is_not_gov_uk( "email_address", [ "notfound@example.gsa.gov", - "example@lsquo.net", + "example@lsquo.si.edu", ], ) def test_should_add_user_details_to_session( @@ -401,6 +401,26 @@ def test_check_invited_user_email_address_doesnt_match_expected(mocker): mock_abort.assert_called_once_with(403) +def test_check_user_email_address_fails_if_not_government_address(mocker): + mock_flash = mocker.patch("app.main.views.register.flash") + mock_abort = mocker.patch("app.main.views.register.abort") + + check_invited_user_email_address_matches_expected( + "fake@fake.bogus", "Fake@Fake.BOGUS" + ) + mock_flash.assert_called_once_with("You must use a government email address.") + mock_abort.assert_called_once_with(403) + + +def test_check_user_email_address_succeeds_if_government_address(mocker): + mock_flash = mocker.patch("app.main.views.register.flash") + mock_abort = mocker.patch("app.main.views.register.abort") + + check_invited_user_email_address_matches_expected("fake@fake.mil", "Fake@Fake.MIL") + mock_flash.assert_not_called() + mock_abort.assert_not_called() + + def decode_invite_data(state): state = state.encode("utf8") state = base64.b64decode(state)