diff --git a/app/main/views/index.py b/app/main/views/index.py index 8d00305c7..7728cb325 100644 --- a/app/main/views/index.py +++ b/app/main/views/index.py @@ -41,7 +41,7 @@ def index(): current_app.config["SECRET_KEY"], current_app.config["DANGEROUS_SALT"], ) - state_key = f"login-nonce-{unquote(state)}" + state_key = f"login-state-{unquote(state)}" redis_client.set(state_key, state) # make and store the nonce diff --git a/app/main/views/register.py b/app/main/views/register.py index 8f2ae34a7..10422aefd 100644 --- a/app/main/views/register.py +++ b/app/main/views/register.py @@ -2,6 +2,7 @@ import base64 import json import uuid from datetime import datetime, timedelta +from urllib.parse import unquote from flask import ( abort, @@ -161,6 +162,13 @@ def set_up_your_profile(): debug_msg(f"Enter set_up_your_profile with request.args {request.args}") code = request.args.get("code") state = request.args.get("state") + + state_key = f"login-state-{unquote(state)}" + stored_state = redis_client.get(state_key).decode("utf8") + if state != stored_state: + current_app.logger.error(f"State Error: {state} != {stored_state}") + abort(403) + login_gov_error = request.args.get("error") if redis_client.get(f"invitedata-{state}") is None: diff --git a/app/main/views/sign_in.py b/app/main/views/sign_in.py index 34f5badb3..8ae3a9770 100644 --- a/app/main/views/sign_in.py +++ b/app/main/views/sign_in.py @@ -99,11 +99,6 @@ def _do_login_dot_gov(): # $ pragma: no cover # start login.gov code = request.args.get("code") state = request.args.get("state") - state_key = f"login-state-{unquote(state)}" - stored_state = redis_client.get(state_key).decode("utf8") - if state != stored_state: - current_app.logger.error(f"State Error: {state} != {stored_state}") - abort(403) login_gov_error = request.args.get("error") @@ -113,6 +108,11 @@ def _do_login_dot_gov(): # $ pragma: no cover ) raise Exception(f"Could not login with login.gov {login_gov_error}") elif code and state: + state_key = f"login-state-{unquote(state)}" + stored_state = redis_client.get(state_key).decode("utf8") + if state != stored_state: + current_app.logger.error(f"State Error: {state} != {stored_state}") + abort(403) # activate the user try: