diff --git a/app/main/views/sign_in.py b/app/main/views/sign_in.py index f1dc7bd88..411c2c620 100644 --- a/app/main/views/sign_in.py +++ b/app/main/views/sign_in.py @@ -17,7 +17,7 @@ from flask import ( ) from flask_login import current_user -from app import login_manager, user_api_client +from app import login_manager, redis_client, user_api_client from app.main import main from app.main.views.index import error from app.main.views.verify import activate_user @@ -72,7 +72,11 @@ def _get_access_token(code, state): raise KeyError(f"'access_token' {response.json()}") from e id_token = jwt.decode(encoded_id_token, keystring, algorithms=["RS256"]) nonce = id_token["nonce"] - if nonce != os.getenv("TOKEN_NONCE"): + state = request.args.get("state") + redis_key = f"token-nonce-{state}" + token_nonce = redis_client.get(redis_key) + redis_client.delete(redis_key) + if nonce != token_nonce: login_manager.unauthorized() try: @@ -206,7 +210,8 @@ def sign_in(): ) url = os.getenv("LOGIN_DOT_GOV_INITIAL_SIGNIN_URL") nonce = secrets.token_urlsafe() - os.environ["TOKEN_NONCE"] = nonce + state = request.args.get("state") + redis_client.set(f"token-nonce-{state}", nonce) # handle unit tests if url is not None: url = url.replace("NONCE", nonce)