Localize notification_utils to the API

This changeset pulls in all of the notification_utils code directly into the API and removes it as an external dependency.  We are doing this to cut down on operational maintenance of the project and will begin removing parts of it no longer needed for the API.

Signed-off-by: Carlo Costino <carlo.costino@gsa.gov>
This commit is contained in:
Carlo Costino
2024-05-16 10:17:45 -04:00
parent 4cdf8b2cb2
commit 99edc88197
129 changed files with 49913 additions and 263 deletions

View File

View File

@@ -0,0 +1,55 @@
import requests
from flask import current_app
class AntivirusError(Exception):
def __init__(self, message=None, status_code=None):
self.message = message
self.status_code = status_code
@classmethod
def from_exception(cls, e):
try:
message = e.response.json()["error"]
status_code = e.response.status_code
except (TypeError, ValueError, AttributeError, KeyError):
message = "connection error"
status_code = 503
return cls(message, status_code)
class AntivirusClient:
def __init__(self, api_host=None, auth_token=None):
self.api_host = api_host
self.auth_token = auth_token
def init_app(self, app):
self.api_host = app.config["ANTIVIRUS_API_HOST"]
self.auth_token = app.config["ANTIVIRUS_API_KEY"]
def scan(self, document_stream):
try:
response = requests.post(
"{}/scan".format(self.api_host),
headers={
"Authorization": "Bearer {}".format(self.auth_token),
},
files={"document": document_stream},
)
response.raise_for_status()
except requests.RequestException as e:
error = AntivirusError.from_exception(e)
current_app.logger.warning(
"Notify Antivirus API request failed with error: {}".format(
error.message
)
)
raise error
finally:
document_stream.seek(0)
return response.json()["ok"]

View File

@@ -0,0 +1,86 @@
from base64 import urlsafe_b64encode
from json import dumps, loads
from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from itsdangerous import BadSignature, URLSafeSerializer
class EncryptionError(Exception):
pass
class SaltLengthError(Exception):
pass
class Encryption:
def init_app(self, app):
self._serializer = URLSafeSerializer(app.config.get("SECRET_KEY"))
self._salt = app.config.get("DANGEROUS_SALT")
self._password = app.config.get("SECRET_KEY").encode()
try:
self._shared_encryptor = Fernet(self._derive_key(self._salt))
except SaltLengthError as reason:
raise EncryptionError(
"DANGEROUS_SALT must be at least 16 bytes"
) from reason
def encrypt(self, thing_to_encrypt, salt=None):
"""Encrypt a string or object
thing_to_encrypt must be serializable as JSON
Returns a UTF-8 string
"""
serialized_bytes = dumps(thing_to_encrypt).encode("utf-8")
encrypted_bytes = self._encryptor(salt).encrypt(serialized_bytes)
return encrypted_bytes.decode("utf-8")
def decrypt(self, thing_to_decrypt, salt=None):
"""Decrypt a UTF-8 string or bytes.
Once decrypted, thing_to_decrypt must be deserializable from JSON.
"""
try:
return loads(self._encryptor(salt).decrypt(thing_to_decrypt))
except InvalidToken as reason:
raise EncryptionError from reason
def sign(self, thing_to_sign, salt=None):
return self._serializer.dumps(thing_to_sign, salt=(salt or self._salt))
def verify_signature(self, thing_to_verify, salt=None):
try:
return self._serializer.loads(thing_to_verify, salt=(salt or self._salt))
except BadSignature as reason:
raise EncryptionError from reason
def _encryptor(self, salt=None):
if salt is None:
return self._shared_encryptor
else:
try:
return Fernet(self._derive_key(salt))
except SaltLengthError as reason:
raise EncryptionError(
"Custom salt value must be at least 16 bytes"
) from reason
def _derive_key(self, salt):
"""Derive a key suitable for use within Fernet from the SECRET_KEY and salt
* For the salt to be secure, it must be 16 bytes or longer and randomly generated.
* 600_000 was chosen for the iterations because it is what OWASP recommends as
* of [February 2023](https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html#pbkdf2)
* For more information, see https://cryptography.io/en/latest/hazmat/primitives/key-derivation-functions/#pbkdf2
* and https://cryptography.io/en/latest/fernet/#using-passwords-with-fernet
"""
salt_bytes = salt.encode()
if len(salt_bytes) < 16:
raise SaltLengthError
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(), length=32, salt=salt_bytes, iterations=600_000
)
return urlsafe_b64encode(kdf.derive(self._password))

View File

@@ -0,0 +1,13 @@
from datetime import datetime
from .request_cache import RequestCache # noqa: F401 (unused import)
def total_limit_cache_key(service_id):
return "{}-{}-{}".format(
str(service_id), datetime.utcnow().strftime("%Y-%m-%d"), "total-count"
)
def rate_limit_cache_key(service_id, api_key_type):
return "{}-{}".format(str(service_id), api_key_type)

View File

@@ -0,0 +1,184 @@
import numbers
import uuid
from time import time
from flask import current_app
from flask_redis import FlaskRedis
def prepare_value(val):
"""
Only bytes, strings and numbers (ints, longs and floats) are acceptable
for keys and values. Previously redis-py attempted to cast other types
to str() and store the result. This caused must confusion and frustration
when passing boolean values (cast to 'True' and 'False') or None values
(cast to 'None'). It is now the user's responsibility to cast all
key names and values to bytes, strings or numbers before passing the
value to redis-py.
"""
# things redis-py natively supports
if isinstance(
val,
(
bytes,
str,
numbers.Number,
),
):
return val
# things we know we can safely cast to string
elif isinstance(val, (uuid.UUID,)):
return str(val)
else:
raise ValueError("cannot cast {} to a string".format(type(val)))
class RedisClient:
redis_store = FlaskRedis()
active = False
scripts = {}
def init_app(self, app):
self.active = app.config.get("REDIS_ENABLED")
if self.active:
self.redis_store.init_app(app)
self.register_scripts()
def register_scripts(self):
# delete keys matching a pattern supplied as a parameter. Does so in batches of 5000 to prevent unpack from
# exceeding lua's stack limit, and also to prevent errors if no keys match the pattern.
# Inspired by https://gist.github.com/ddre54/0a4751676272e0da8186
self.scripts["delete-keys-by-pattern"] = self.redis_store.register_script(
"""
local keys = redis.call('keys', ARGV[1])
local deleted = 0
for i=1, #keys, 5000 do
deleted = deleted + redis.call('del', unpack(keys, i, math.min(i + 4999, #keys)))
end
return deleted
"""
)
def delete_by_pattern(self, pattern, raise_exception=False):
r"""
Deletes all keys matching a given pattern, and returns how many keys were deleted.
Pattern is defined as in the KEYS command: https://redis.io/commands/keys
* h?llo matches hello, hallo and hxllo
* h*llo matches hllo and heeeello
* h[ae]llo matches hello and hallo, but not hillo
* h[^e]llo matches hallo, hbllo, ... but not hello
* h[a-b]llo matches hallo and hbllo
Use \ to escape special characters if you want to match them verbatim
"""
if self.active:
try:
return self.scripts["delete-keys-by-pattern"](args=[pattern])
except Exception as e:
self.__handle_exception(
e, raise_exception, "delete-by-pattern", pattern
)
return 0
def exceeded_rate_limit(self, cache_key, limit, interval, raise_exception=False):
"""
Rate limiting.
- Uses Redis sorted sets
- Also uses redis "multi" which is abstracted into pipeline() by FlaskRedis/PyRedis
- Sends all commands to redis as a group to be executed atomically
Method:
(1) Add event, scored by timestamp (zadd). The score determines order in set.
(2) Use zremrangebyscore to delete all set members with a score between
- Earliest entry (lowest score == earliest timestamp) - represented as '-inf'
and
- Current timestamp minus the interval
- Leaves only relevant entries in the set (those between now and now - interval)
(3) Count the set
(4) If count > limit fail request
(5) Ensure we expire the set key to preserve space
Notes:
- Failed requests count. If over the limit and keep making requests you'll stay over the limit.
- The actual value in the set is just the timestamp, the same as the score. We don't store any requets details.
- return value of pipe.execute() is an array containing the outcome of each call.
- result[2] == outcome of pipe.zcard()
- If redis is inactive, or we get an exception, allow the request
:param cache_key:
:param limit: Number of requests permitted within interval
:param interval: Interval we measure requests in
:param raise_exception: Should throw exception
:return:
"""
cache_key = prepare_value(cache_key)
if self.active:
try:
pipe = self.redis_store.pipeline()
when = time()
pipe.zadd(cache_key, {when: when})
pipe.zremrangebyscore(cache_key, "-inf", when - interval)
pipe.zcard(cache_key)
pipe.expire(cache_key, interval)
result = pipe.execute()
return result[2] > limit
except Exception as e:
self.__handle_exception(
e, raise_exception, "rate-limit-pipeline", cache_key
)
return False
else:
return False
def raw_set(self, key, value, ex=None, px=None, nx=False, xx=False):
self.redis_store.set(key, value, ex, px, nx, xx)
def set(
self, key, value, ex=None, px=None, nx=False, xx=False, raise_exception=False
):
key = prepare_value(key)
value = prepare_value(value)
if self.active:
try:
self.redis_store.set(key, value, ex, px, nx, xx)
except Exception as e:
self.__handle_exception(e, raise_exception, "set", key)
def incr(self, key, raise_exception=False):
key = prepare_value(key)
if self.active:
try:
return self.redis_store.incr(key)
except Exception as e:
self.__handle_exception(e, raise_exception, "incr", key)
def raw_get(self, key):
return self.redis_store.get(key)
def get(self, key, raise_exception=False):
key = prepare_value(key)
if self.active:
try:
return self.redis_store.get(key)
except Exception as e:
self.__handle_exception(e, raise_exception, "get", key)
return None
def delete(self, *keys, raise_exception=False):
keys = [prepare_value(k) for k in keys]
if self.active:
try:
self.redis_store.delete(*keys)
except Exception as e:
self.__handle_exception(e, raise_exception, "delete", ", ".join(keys))
def __handle_exception(self, e, raise_exception, operation, key_name):
current_app.logger.exception(
"Redis error performing {} on {}".format(operation, key_name)
)
if raise_exception:
raise e

View File

@@ -0,0 +1,95 @@
import json
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import signature
class RequestCache:
DEFAULT_TTL = int(timedelta(days=7).total_seconds())
def __init__(self, redis_client):
self.redis_client = redis_client
@staticmethod
def _get_argument(argument_name, client_method, args, kwargs):
with suppress(KeyError):
return kwargs[argument_name]
with suppress(ValueError, IndexError):
argument_index = list(signature(client_method).parameters).index(
argument_name
)
return args[argument_index]
with suppress(KeyError):
return signature(client_method).parameters[argument_name].default
raise TypeError(
"{}() takes no argument called '{}'".format(
client_method.__name__, argument_name
)
)
@staticmethod
def _make_key(key_format, client_method, args, kwargs):
return key_format.format(
**{
argument_name: RequestCache._get_argument(
argument_name, client_method, args, kwargs
)
for argument_name in list(signature(client_method).parameters)
}
)
def set(self, key_format, *, ttl_in_seconds=DEFAULT_TTL):
def _set(client_method):
@wraps(client_method)
def new_client_method(*args, **kwargs):
redis_key = RequestCache._make_key(
key_format, client_method, args, kwargs
)
cached = self.redis_client.get(redis_key)
if cached:
return json.loads(cached.decode("utf-8"))
api_response = client_method(*args, **kwargs)
self.redis_client.set(
redis_key,
json.dumps(api_response),
ex=int(ttl_in_seconds),
)
return api_response
return new_client_method
return _set
def delete(self, key_format):
def _delete(client_method):
@wraps(client_method)
def new_client_method(*args, **kwargs):
try:
api_response = client_method(*args, **kwargs)
finally:
redis_key = self._make_key(key_format, client_method, args, kwargs)
self.redis_client.delete(redis_key)
return api_response
return new_client_method
return _delete
def delete_by_pattern(self, key_format):
def _delete(client_method):
@wraps(client_method)
def new_client_method(*args, **kwargs):
try:
api_response = client_method(*args, **kwargs)
finally:
redis_key = self._make_key(key_format, client_method, args, kwargs)
self.redis_client.delete_by_pattern(redis_key)
return api_response
return new_client_method
return _delete

View File

@@ -0,0 +1,150 @@
import requests
from flask import current_app
class ZendeskError(Exception):
def __init__(self, response):
self.response = response
class ZendeskClient:
# the account used to authenticate with. If no requester is provided, the ticket will come from this account.
NOTIFY_ZENDESK_EMAIL = "zd-api-notify@digital.cabinet-office.gov.uk"
ZENDESK_TICKET_URL = "https://govuk.zendesk.com/api/v2/tickets.json"
def __init__(self):
self.api_key = None
def init_app(self, app, *args, **kwargs):
self.api_key = app.config.get("ZENDESK_API_KEY")
def send_ticket_to_zendesk(self, ticket):
response = requests.post(
self.ZENDESK_TICKET_URL,
json=ticket.request_data,
auth=(f"{self.NOTIFY_ZENDESK_EMAIL}/token", self.api_key),
)
if response.status_code != 201:
current_app.logger.error(
f"Zendesk create ticket request failed with {response.status_code} '{response.json()}'"
)
raise ZendeskError(response)
ticket_id = response.json()["ticket"]["id"]
current_app.logger.info(f"Zendesk create ticket {ticket_id} succeeded")
class NotifySupportTicket:
PRIORITY_URGENT = "urgent"
PRIORITY_HIGH = "high"
PRIORITY_NORMAL = "normal"
PRIORITY_LOW = "low"
TAGS_P2 = "govuk_notify_support"
TAGS_P1 = "govuk_notify_emergency"
TYPE_PROBLEM = "problem"
TYPE_INCIDENT = "incident"
TYPE_QUESTION = "question"
TYPE_TASK = "task"
# Group: 3rd Line--Notify Support
NOTIFY_GROUP_ID = 360000036529
# Organization: GDS
NOTIFY_ORG_ID = 21891972
NOTIFY_TICKET_FORM_ID = 1900000284794
def __init__(
self,
subject,
message,
ticket_type,
p1=False,
user_name=None,
user_email=None,
requester_sees_message_content=True,
technical_ticket=False,
ticket_categories=None,
org_id=None,
org_type=None,
service_id=None,
email_ccs=None,
):
self.subject = subject
self.message = message
self.ticket_type = ticket_type
self.p1 = p1
self.user_name = user_name
self.user_email = user_email
self.requester_sees_message_content = requester_sees_message_content
self.technical_ticket = technical_ticket
self.ticket_categories = ticket_categories or []
self.org_id = org_id
self.org_type = org_type
self.service_id = service_id
self.email_ccs = email_ccs
@property
def request_data(self):
data = {
"ticket": {
"subject": self.subject,
"comment": {
"body": self.message,
"public": self.requester_sees_message_content,
},
"group_id": self.NOTIFY_GROUP_ID,
"organization_id": self.NOTIFY_ORG_ID,
"ticket_form_id": self.NOTIFY_TICKET_FORM_ID,
"priority": self.PRIORITY_URGENT if self.p1 else self.PRIORITY_NORMAL,
"tags": [self.TAGS_P1 if self.p1 else self.TAGS_P2],
"type": self.ticket_type,
"custom_fields": self._get_custom_fields(),
}
}
if self.email_ccs:
data["ticket"]["email_ccs"] = [
{"user_email": email, "action": "put"} for email in self.email_ccs
]
# if no requester provided, then the call came from within Notify 👻
if self.user_email:
data["ticket"]["requester"] = {
"email": self.user_email,
"name": self.user_name or "(no name supplied)",
}
return data
def _get_custom_fields(self):
technical_ticket_tag = (
f'notify_ticket_type_{"" if self.technical_ticket else "non_"}technical'
)
org_type_tag = f"notify_org_type_{self.org_type}" if self.org_type else None
return [
{
"id": "1900000744994",
"value": technical_ticket_tag,
}, # Notify Ticket type field
{
"id": "360022836500",
"value": self.ticket_categories,
}, # Notify Ticket category field
{
"id": "360022943959",
"value": self.org_id,
}, # Notify Organisation ID field
{
"id": "360022943979",
"value": org_type_tag,
}, # Notify Organisation type field
{
"id": "1900000745014",
"value": self.service_id,
}, # Notify Service ID field
]