diff --git a/app/main/views/new_password.py b/app/main/views/new_password.py index 7b3c701f9..2f35d9ddf 100644 --- a/app/main/views/new_password.py +++ b/app/main/views/new_password.py @@ -5,6 +5,7 @@ from flask import ( flash, redirect, render_template, + request, session, url_for, ) @@ -46,6 +47,6 @@ def new_password(token): else: # send user a 2fa sms code user.send_verify_code() - return redirect(url_for('main.two_factor')) + return redirect(url_for('main.two_factor', next=request.args.get('next'))) else: return render_template('views/new-password.html', token=token, form=form, user=user) diff --git a/tests/app/main/views/test_new_password.py b/tests/app/main/views/test_new_password.py index 80359bcd7..e7641438d 100644 --- a/tests/app/main/views/test_new_password.py +++ b/tests/app/main/views/test_new_password.py @@ -1,6 +1,7 @@ import json from datetime import datetime +import pytest from flask import url_for from itsdangerous import SignatureExpired from notifications_utils.url_safe_token import generate_token @@ -36,21 +37,26 @@ def test_should_return_404_when_email_address_does_not_exist( assert response.status_code == 404 +@pytest.mark.parametrize('redirect_url', [ + None, + 'blob', +]) def test_should_redirect_to_two_factor_when_password_reset_is_successful( app_, client, mock_get_user_by_email_request_password_reset, mock_login, mock_send_verify_code, - mock_reset_failed_login_count + mock_reset_failed_login_count, + redirect_url ): user = mock_get_user_by_email_request_password_reset.return_value data = json.dumps({'email': user['email_address'], 'created_at': str(datetime.utcnow())}) token = generate_token(data, app_.config['SECRET_KEY'], app_.config['DANGEROUS_SALT']) - response = client.post(url_for_endpoint_with_token('.new_password', token=token), + response = client.post(url_for_endpoint_with_token('.new_password', token=token, next=redirect_url), data={'new_password': 'a-new_password'}) assert response.status_code == 302 - assert response.location == url_for('.two_factor', _external=True) + assert response.location == url_for('.two_factor', _external=True, next=redirect_url) mock_get_user_by_email_request_password_reset.assert_called_once_with(user['email_address']) diff --git a/tests/conftest.py b/tests/conftest.py index 9015c7cf0..a2c2c7121 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3646,9 +3646,9 @@ def mock_create_event(mocker): return mocker.patch('app.events_api_client.create_event', side_effect=_add_event) -def url_for_endpoint_with_token(endpoint, token): +def url_for_endpoint_with_token(endpoint, token, next=None): token = token.replace('%2E', '.') - return url_for(endpoint, token=token) + return url_for(endpoint, token=token, next=next) @pytest.fixture