diff --git a/app/main/views/register.py b/app/main/views/register.py index 19187a47c..a8de8d1e8 100644 --- a/app/main/views/register.py +++ b/app/main/views/register.py @@ -26,6 +26,7 @@ from app.main.views import sign_in from app.main.views.verify import activate_user from app.models.user import InvitedOrgUser, InvitedUser, User from app.utils import hide_from_search_engines, hilite +from app.utils.login import get_id_token from app.utils.user import is_gov_user @@ -165,6 +166,15 @@ def set_up_your_profile(): if redis_client.get(f"invitedata-{state}") is None: access_token = sign_in._get_access_token(code, state) + + request_json = request.json() + id_token = get_id_token(request_json) + nonce = id_token["nonce"] + stored_nonce = redis_client.get(f"invitenonce-{state}") + if nonce != stored_nonce: + current_app.logger.error(f"Nonce Error: {nonce} != {stored_nonce}") + abort(403) + debug_msg("Got the access token for login.gov") user_email, user_uuid = sign_in._get_user_email_and_uuid(access_token) debug_msg( diff --git a/app/main/views/sign_in.py b/app/main/views/sign_in.py index a4026b485..700b01a02 100644 --- a/app/main/views/sign_in.py +++ b/app/main/views/sign_in.py @@ -1,4 +1,3 @@ -# import json import os import secrets import time @@ -25,7 +24,7 @@ from app.main.views.index import error from app.main.views.verify import activate_user from app.models.user import User from app.utils import hide_from_search_engines -from app.utils.login import is_safe_redirect_url +from app.utils.login import get_id_token, is_safe_redirect_url from app.utils.time import is_less_than_days_ago from app.utils.user import is_gov_user from notifications_utils.url_safe_token import generate_token @@ -43,7 +42,6 @@ def _reformat_keystring(orig): # pragma: no cover def _get_access_token(code, state): # pragma: no cover client_id = os.getenv("LOGIN_DOT_GOV_CLIENT_ID") access_token_url = os.getenv("LOGIN_DOT_GOV_ACCESS_TOKEN_URL") - # certs_url = os.getenv("LOGIN_DOT_GOV_CERTS_URL") keystring = os.getenv("LOGIN_PEM") if " " in keystring: keystring = _reformat_keystring(keystring) @@ -66,38 +64,12 @@ def _get_access_token(code, state): # pragma: no cover response = requests.post(url, headers=headers) response_json = response.json() - - # TODO nonce check intermittently fails, investifix - # Presumably the nonce is not yet in the session when there - # is an invite involved? - - # try: - # encoded_id_token = response_json["id_token"] - # except KeyError as e: - # current_app.logger.exception(f"Error when getting id token {response_json}") - # raise KeyError(f"'access_token' {response.json()}") from e - - # Getting Login.gov signing keys for unpacking the id_token correctly. - # jwks = requests.get(certs_url).json() - # public_keys = { - # jwk["kid"]: { - # "key": jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(jwk)), - # "algo": jwk["alg"], - # } - # for jwk in jwks["keys"] - # } - # kid = jwt.get_unverified_header(encoded_id_token)["kid"] - # pub_key = public_keys[kid]["key"] - # algo = public_keys[kid]["algo"] - # id_token = jwt.decode( - # encoded_id_token, pub_key, audience=client_id, algorithms=[algo] - # ) - # nonce = id_token["nonce"] - - # saved_nonce = session.pop("nonce") - # if nonce != saved_nonce: - # current_app.logger.error(f"Nonce Error: {nonce} != {saved_nonce}") - # abort(403) + id_token = get_id_token(response_json) + nonce = id_token["nonce"] + stored_nonce = session.pop("nonce") + if nonce != stored_nonce: + current_app.logger.error(f"Nonce Error: {nonce} != {stored_nonce}") + abort(403) try: access_token = response_json["access_token"] diff --git a/app/notify_client/invite_api_client.py b/app/notify_client/invite_api_client.py index d410ceec5..25a69967b 100644 --- a/app/notify_client/invite_api_client.py +++ b/app/notify_client/invite_api_client.py @@ -1,3 +1,8 @@ +import base64 +import json +import secrets + +from app import redis_client from app.notify_client import NotifyAdminAPIClient, _attach_current_user, cache from app.utils.user_permissions import ( all_ui_permissions, @@ -32,6 +37,18 @@ class InviteApiClient(NotifyAdminAPIClient): "folder_permissions": folder_permissions, } data = _attach_current_user(data) + + # make the state variable to properly store the nonce. + # this matches the api code in app.service_invite.rest.get_user_data_url_safe() + state_data = json.dumps(data) + state_data = base64.b64encode(state_data.encode("utf8")) + state = state_data.decode("utf8") + + # make and store the nonce + nonce = secrets.token_urlsafe() + redis_client.set(f"invitenonce-{state}", nonce) # save the nonce to redis. + data["nonce"] = nonce # This is passed to api for the invite url. + resp = self.post(url=f"/service/{service_id}/invite", data=data) return resp["data"] diff --git a/app/utils/login.py b/app/utils/login.py index 2d060ac86..2c3dec108 100644 --- a/app/utils/login.py +++ b/app/utils/login.py @@ -1,6 +1,10 @@ +import json +import os from functools import wraps -from flask import redirect, request, session, url_for +import jwt +import requests +from flask import current_app, redirect, request, session, url_for from app.models.user import User from app.utils.time import is_less_than_days_ago @@ -57,3 +61,32 @@ def is_safe_redirect_url(target): redirect_url.scheme in ("http", "https") and host_url.netloc == redirect_url.netloc ) + + +def get_id_token(json_data): + """Decode and return the id_token.""" + client_id = os.getenv("LOGIN_DOT_GOV_CLIENT_ID") + certs_url = os.getenv("LOGIN_DOT_GOV_CERTS_URL") + + try: + encoded_id_token = json_data["id_token"] + except KeyError as e: + current_app.logger.exception(f"Error when getting id token {json_data}") + raise KeyError(f"'access_token' {request.json()}") from e + + # Getting Login.gov signing keys for unpacking the id_token correctly. + jwks = requests.get(certs_url, timeout=5).json() + public_keys = { + jwk["kid"]: { + "key": jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(jwk)), + "algo": jwk["alg"], + } + for jwk in jwks["keys"] + } + kid = jwt.get_unverified_header(encoded_id_token)["kid"] + pub_key = public_keys[kid]["key"] + algo = public_keys[kid]["algo"] + id_token = jwt.decode( + encoded_id_token, pub_key, audience=client_id, algorithms=[algo] + ) + return id_token