diff --git a/app/main/views/code_not_received.py b/app/main/views/code_not_received.py index 13ce5bfb2..cebb73b7e 100644 --- a/app/main/views/code_not_received.py +++ b/app/main/views/code_not_received.py @@ -8,19 +8,20 @@ from flask import ( from app import user_api_client from app.main import main from app.main.forms import TextNotReceivedForm +from app.utils import redirect_to_sign_in @main.route('/resend-email-verification') +@redirect_to_sign_in def resend_email_verification(): - # TODO there needs to be a way to regenerate a session id user = user_api_client.get_user_by_email(session['user_details']['email']) user_api_client.send_verify_email(user.id, user.email_address) return render_template('views/resend-email-verification.html', email=user.email_address) @main.route('/text-not-received', methods=['GET', 'POST']) +@redirect_to_sign_in def check_and_resend_text_code(): - # TODO there needs to be a way to regenerate a session id user = user_api_client.get_user_by_email(session['user_details']['email']) if user.state == 'active': @@ -38,8 +39,8 @@ def check_and_resend_text_code(): @main.route('/send-new-code', methods=['GET']) +@redirect_to_sign_in def check_and_resend_verification_code(): - # TODO there needs to be a way to generate a new session id user = user_api_client.get_user_by_email(session['user_details']['email']) user_api_client.send_verify_code(user.id, 'sms', user.mobile_number) if user.state == 'pending': diff --git a/app/main/views/two_factor.py b/app/main/views/two_factor.py index b8330a6ee..a728aea19 100644 --- a/app/main/views/two_factor.py +++ b/app/main/views/two_factor.py @@ -9,17 +9,14 @@ from flask import ( from flask_login import login_user, current_user from app.main import main from app.main.forms import TwoFactorForm -from app import service_api_client -from app import user_api_client +from app import service_api_client, user_api_client +from app.utils import redirect_to_sign_in @main.route('/two-factor', methods=['GET', 'POST']) +@redirect_to_sign_in def two_factor(): - # TODO handle user_email not in session - try: - user_id = session['user_details']['id'] - except KeyError: - return redirect(url_for('main.sign_in')) + user_id = session['user_details']['id'] def _check_code(code): return user_api_client.check_verify_code(user_id, code, "sms") diff --git a/app/main/views/verify.py b/app/main/views/verify.py index 71ee28ee0..fc9d48691 100644 --- a/app/main/views/verify.py +++ b/app/main/views/verify.py @@ -18,14 +18,14 @@ from notifications_utils.url_safe_token import check_token from app.main import main from app.main.forms import TwoFactorForm +from app.utils import redirect_to_sign_in from app import user_api_client @main.route('/verify', methods=['GET', 'POST']) +@redirect_to_sign_in def verify(): - # TODO there needs to be a way to regenerate a session id - # or handle gracefully. user_id = session['user_details']['id'] def _check_code(code): diff --git a/app/utils.py b/app/utils.py index 052c52981..242af78dc 100644 --- a/app/utils.py +++ b/app/utils.py @@ -1,9 +1,9 @@ import re import csv -from io import BytesIO, StringIO +from io import StringIO from os import path from functools import wraps -from flask import (abort, session, request, url_for) +from flask import (abort, session, request, redirect, url_for) import pyexcel import pyexcel.ext.io import pyexcel.ext.xls @@ -51,6 +51,16 @@ def user_has_permissions(*permissions, admin_override=False, any_=False): return wrap +def redirect_to_sign_in(f): + @wraps(f) + def wrapped(*args, **kwargs): + if 'user_details' not in session: + return redirect(url_for('main.sign_in')) + else: + return f(*args, **kwargs) + return wrapped + + def get_errors_for_csv(recipients, template_type): errors = [] diff --git a/tests/app/main/views/test_code_not_received.py b/tests/app/main/views/test_code_not_received.py index 54cea2585..271b6ee81 100644 --- a/tests/app/main/views/test_code_not_received.py +++ b/tests/app/main/views/test_code_not_received.py @@ -1,6 +1,5 @@ - +import pytest from flask import url_for - from bs4 import BeautifulSoup @@ -128,3 +127,16 @@ def test_check_and_redirect_to_verify_if_user_pending(app_, response = client.get(url_for('main.check_and_resend_verification_code')) assert response.status_code == 302 assert response.location == url_for('main.verify', _external=True) + + +@pytest.mark.parametrize('endpoint', [ + 'main.resend_email_verification', + 'main.check_and_resend_text_code', + 'main.check_and_resend_verification_code', +]) +def test_redirect_to_sign_in_if_not_logged_in(app_, endpoint): + with app_.test_request_context(), app_.test_client() as client: + response = client.get(url_for(endpoint)) + + assert response.location == url_for('main.sign_in', _external=True) + assert response.status_code == 302 diff --git a/tests/app/main/views/test_verify.py b/tests/app/main/views/test_verify.py index 778df90dd..2a7599890 100644 --- a/tests/app/main/views/test_verify.py +++ b/tests/app/main/views/test_verify.py @@ -147,3 +147,11 @@ def test_verify_email_redirects_to_sign_in_if_user_active(app_, assert page.h1.text == 'Sign in' flash_banner = page.find('div', class_='banner-dangerous').string.strip() assert flash_banner == "That verification link has expired." + + +def test_verify_redirects_to_sign_in_if_not_logged_in(app_): + with app_.test_request_context(), app_.test_client() as client: + response = client.get(url_for('main.verify')) + + assert response.location == url_for('main.sign_in', _external=True) + assert response.status_code == 302