diff --git a/app/main/views/index.py b/app/main/views/index.py
index 8b965991c..7728cb325 100644
--- a/app/main/views/index.py
+++ b/app/main/views/index.py
@@ -35,24 +35,24 @@ def index():
if current_user and current_user.is_authenticated:
return redirect(url_for("main.choose_account"))
- token = generate_token(
+ # make and store the state
+ state = generate_token(
str(request.remote_addr),
current_app.config["SECRET_KEY"],
current_app.config["DANGEROUS_SALT"],
)
- url = os.getenv("LOGIN_DOT_GOV_INITIAL_SIGNIN_URL")
- # handle unit tests
-
- current_app.logger.warning(f"############### {str(request.remote_addr)}")
+ state_key = f"login-state-{unquote(state)}"
+ redis_client.set(state_key, state)
+ # make and store the nonce
nonce = secrets.token_urlsafe()
+ nonce_key = f"login-nonce-{unquote(nonce)}"
+ redis_client.set(nonce_key, nonce)
- redis_key = f"login-nonce-{unquote(nonce)}"
- redis_client.set(redis_key, nonce)
-
+ url = os.getenv("LOGIN_DOT_GOV_INITIAL_SIGNIN_URL")
if url is not None:
url = url.replace("NONCE", nonce)
- url = url.replace("STATE", token)
+ url = url.replace("STATE", state)
return render_template(
"views/signedout.html",
sms_rate=CURRENT_SMS_RATE,
diff --git a/app/main/views/invites.py b/app/main/views/invites.py
index 49fb66f88..07f6b3ac8 100644
--- a/app/main/views/invites.py
+++ b/app/main/views/invites.py
@@ -17,14 +17,12 @@ def accept_invite(token):
and current_user.email_address.lower() != invited_user.email_address.lower()
):
message = Markup(
- """
- You’re signed in as {}.
+ f"""
+ You’re signed in as {current_user.email_address}.
This invite is for another email address.
- Sign out
+ Sign out
and click the link again to accept this invite.
- """.format(
- current_user.email_address, url_for("main.sign_out")
- )
+ """
)
flash(message=message)
diff --git a/app/main/views/register.py b/app/main/views/register.py
index cc6850055..d0ceb65a9 100644
--- a/app/main/views/register.py
+++ b/app/main/views/register.py
@@ -1,7 +1,7 @@
-import base64
import json
import uuid
from datetime import datetime, timedelta
+from urllib.parse import unquote
from flask import (
abort,
@@ -161,18 +161,29 @@ def set_up_your_profile():
debug_msg(f"Enter set_up_your_profile with request.args {request.args}")
code = request.args.get("code")
state = request.args.get("state")
+
+ state_key = f"login-state-{unquote(state)}"
+ stored_state = unquote(redis_client.get(state_key).decode("utf8"))
+ if state != stored_state:
+ current_app.logger.error(f"State Error: {state} != {stored_state}")
+ abort(403)
+
login_gov_error = request.args.get("error")
- if redis_client.get(f"invitedata-{state}") is None:
- access_token = sign_in._get_access_token(code, state)
+ user_email = redis_client.get(f"user_email-{state}")
+ user_uuid = redis_client.get(f"user_uuid-{state}")
+
+ new_user = user_email is None or user_uuid is None
+
+ if new_user: # invite path
+ access_token = sign_in._get_access_token(code)
debug_msg("Got the access token for login.gov")
user_email, user_uuid = sign_in._get_user_email_and_uuid(access_token)
debug_msg(
f"Got the user_email {user_email} and user_uuid {user_uuid} from login.gov"
)
- invite_data = state.encode("utf8")
- invite_data = base64.b64decode(invite_data)
+ invite_data = redis_client.get(f"invitedata-{state}")
invite_data = json.loads(invite_data)
debug_msg(f"final state {invite_data}")
invited_user_id = invite_data["invited_user_id"]
@@ -194,10 +205,7 @@ def set_up_your_profile():
form = SetupUserProfileForm()
- if (
- form.validate_on_submit()
- and redis_client.get(f"invitedata-{state}") is not None
- ):
+ if form.validate_on_submit() and not new_user:
invite_data, user_email, user_uuid, invited_user_email_address = (
get_invite_data_from_redis(state)
)
@@ -222,6 +230,7 @@ def set_up_your_profile():
activate_user(user["id"])
debug_msg("activated user")
usr = User.from_id(user["id"])
+
usr.add_to_service(
invite_data["service_id"],
invite_data["permissions"],
diff --git a/app/main/views/sign_in.py b/app/main/views/sign_in.py
index a326202f3..d948f459e 100644
--- a/app/main/views/sign_in.py
+++ b/app/main/views/sign_in.py
@@ -39,7 +39,7 @@ def _reformat_keystring(orig): # pragma: no cover
return new_keystring
-def _get_access_token(code, state): # pragma: no cover
+def _get_access_token(code): # pragma: no cover
client_id = os.getenv("LOGIN_DOT_GOV_CLIENT_ID")
access_token_url = os.getenv("LOGIN_DOT_GOV_ACCESS_TOKEN_URL")
keystring = os.getenv("LOGIN_PEM")
@@ -66,8 +66,8 @@ def _get_access_token(code, state): # pragma: no cover
response_json = response.json()
id_token = get_id_token(response_json)
nonce = id_token["nonce"]
- redis_key = f"login-nonce-{unquote(nonce)}"
- stored_nonce = redis_client.get(redis_key).decode("utf8")
+ nonce_key = f"login-nonce-{unquote(nonce)}"
+ stored_nonce = redis_client.get(nonce_key).decode("utf8")
if nonce != stored_nonce:
current_app.logger.error(f"Nonce Error: {nonce} != {stored_nonce}")
@@ -99,6 +99,7 @@ def _do_login_dot_gov(): # $ pragma: no cover
# start login.gov
code = request.args.get("code")
state = request.args.get("state")
+
login_gov_error = request.args.get("error")
if login_gov_error:
@@ -107,10 +108,15 @@ def _do_login_dot_gov(): # $ pragma: no cover
)
raise Exception(f"Could not login with login.gov {login_gov_error}")
elif code and state:
+ state_key = f"login-state-{unquote(state)}"
+ stored_state = unquote(redis_client.get(state_key).decode("utf8"))
+ if state != stored_state:
+ current_app.logger.error(f"State Error: {state} != {stored_state}")
+ abort(403)
# activate the user
try:
- access_token = _get_access_token(code, state)
+ access_token = _get_access_token(code)
user_email, user_uuid = _get_user_email_and_uuid(access_token)
if not is_gov_user(user_email):
current_app.logger.error(
@@ -203,21 +209,23 @@ def sign_in(): # pragma: no cover
return redirect(redirect_url)
return redirect(url_for("main.show_accounts_or_dashboard"))
- token = generate_token(
+ state = generate_token(
str(request.remote_addr),
current_app.config["SECRET_KEY"],
current_app.config["DANGEROUS_SALT"],
)
- url = os.getenv("LOGIN_DOT_GOV_INITIAL_SIGNIN_URL")
+ state_key = f"login-state-{unquote(state)}"
+ redis_client.set(state_key, state)
nonce = secrets.token_urlsafe()
- redis_key = f"-{unquote(nonce)}"
- redis_client.set(redis_key, nonce)
+ nonce_key = f"login-nonce-{unquote(nonce)}"
+ redis_client.set(nonce_key, nonce)
+ url = os.getenv("LOGIN_DOT_GOV_INITIAL_SIGNIN_URL")
# handle unit tests
if url is not None:
url = url.replace("NONCE", nonce)
- url = url.replace("STATE", token)
+ url = url.replace("STATE", state)
return render_template(
"views/signin.html",
diff --git a/app/notify_client/invite_api_client.py b/app/notify_client/invite_api_client.py
index 711cc1f55..39cf7dbce 100644
--- a/app/notify_client/invite_api_client.py
+++ b/app/notify_client/invite_api_client.py
@@ -1,12 +1,16 @@
+import json
import secrets
from urllib.parse import unquote
+from flask import current_app, request
+
from app import redis_client
from app.notify_client import NotifyAdminAPIClient, _attach_current_user, cache
from app.utils.user_permissions import (
all_ui_permissions,
translate_permissions_from_ui_to_db,
)
+from notifications_utils.url_safe_token import generate_token
class InviteApiClient(NotifyAdminAPIClient):
@@ -37,14 +41,32 @@ class InviteApiClient(NotifyAdminAPIClient):
}
data = _attach_current_user(data)
+ # make and store the state
+ state = generate_token(
+ str(request.remote_addr),
+ current_app.config["SECRET_KEY"],
+ current_app.config["DANGEROUS_SALT"],
+ )
+ state_key = f"login-state-{unquote(state)}"
+ redis_client.set(state_key, state)
+
# make and store the nonce
nonce = secrets.token_urlsafe()
- redis_key = f"login-nonce-{unquote(nonce)}"
- redis_client.set(f"{redis_key}", nonce) # save the nonce to redis.
+ nonce_key = f"login-nonce-{unquote(nonce)}"
+ redis_client.set(nonce_key, nonce) # save the nonce to redis.
+
data["nonce"] = nonce # This is passed to api for the invite url.
+ data["state"] = state # This is passed to api for the invite url.
resp = self.post(url=f"/service/{service_id}/invite", data=data)
- return resp["data"]
+
+ resp_data = resp["data"]
+ invite_data_key = f"invitedata-{unquote(state)}"
+ redis_invite_data = resp["invite"]
+ redis_invite_data = json.dumps(redis_invite_data)
+ redis_client.set(invite_data_key, redis_invite_data)
+
+ return resp_data
def get_invites_for_service(self, service_id):
return self.get(f"/service/{service_id}/invite")["data"]
@@ -75,7 +97,32 @@ class InviteApiClient(NotifyAdminAPIClient):
self.post(url=f"/service/{service_id}/invite/{invited_user_id}", data=data)
def resend_invite(self, service_id, invited_user_id):
- self.post(url=f"/service/{service_id}/invite/{invited_user_id}/resend", data={})
+ # make and store the state
+ state = generate_token(
+ str(request.remote_addr),
+ current_app.config["SECRET_KEY"],
+ current_app.config["DANGEROUS_SALT"],
+ )
+ state_key = f"login-state-{unquote(state)}"
+ redis_client.set(state_key, state)
+
+ # make and store the nonce
+ nonce = secrets.token_urlsafe()
+ nonce_key = f"login-nonce-{unquote(nonce)}"
+ redis_client.set(nonce_key, nonce)
+
+ data = {
+ "nonce": nonce,
+ "state": state,
+ }
+ resp = self.post(
+ url=f"/service/{service_id}/invite/{invited_user_id}/resend", data=data
+ )
+
+ invite_data_key = f"invitedata-{unquote(state)}"
+ redis_invite_data = resp["invite"]
+ redis_invite_data = json.dumps(redis_invite_data)
+ redis_client.set(invite_data_key, redis_invite_data)
@cache.delete("service-{service_id}")
@cache.delete("user-{invited_user_id}")
diff --git a/app/notify_client/user_api_client.py b/app/notify_client/user_api_client.py
index 01a3a78c9..3514a0d6e 100644
--- a/app/notify_client/user_api_client.py
+++ b/app/notify_client/user_api_client.py
@@ -168,7 +168,7 @@ class UserApiClient(NotifyAdminAPIClient):
@cache.delete("user-{user_id}")
def add_user_to_service(self, service_id, user_id, permissions, folder_permissions):
# permissions passed in are the combined UI permissions, not DB permissions
- endpoint = "/service/{}/users/{}".format(service_id, user_id)
+ endpoint = f"/service/{service_id}/users/{user_id}"
data = {
"permissions": [
{"permission": x}
diff --git a/tests/app/notify_client/test_invite_client.py b/tests/app/notify_client/test_invite_client.py
index a251d5367..2a6f5113d 100644
--- a/tests/app/notify_client/test_invite_client.py
+++ b/tests/app/notify_client/test_invite_client.py
@@ -1,5 +1,7 @@
from unittest.mock import ANY
+from flask import current_app
+
from app import invite_api_client
@@ -26,21 +28,35 @@ def test_client_creates_invite(
"auth_type",
"folder_permissions",
"nonce",
+ "state",
}
- )
+ ),
+ "invite": {},
},
)
mock_token_urlsafe = mocker.patch("secrets.token_urlsafe")
fake_nonce = "1234567890"
+ fake_state = "0987654321"
mock_token_urlsafe.return_value = fake_nonce
- invite_api_client.create_invite(
- "12345", "67890", "test@example.com", {"send_messages"}, "sms_auth", [fake_uuid]
+ mock_generate_token = mocker.patch(
+ "app.notify_client.invite_api_client.generate_token"
)
+ mock_generate_token.return_value = fake_state
+
+ with current_app.test_request_context("/whatever"):
+ invite_api_client.create_invite(
+ "12345",
+ "67890",
+ "test@example.com",
+ {"send_messages"},
+ "sms_auth",
+ [fake_uuid],
+ )
mock_post.assert_called_once_with(
- url="/service/{}/invite".format("67890"),
+ url=f"/service/{"67890"}/invite",
data={
"auth_type": "sms_auth",
"email_address": "test@example.com",
@@ -51,6 +67,7 @@ def test_client_creates_invite(
"invite_link_host": "http://localhost:6012",
"folder_permissions": [fake_uuid],
"nonce": fake_nonce,
+ "state": fake_state,
},
)