mirror of
https://github.com/GSA/notifications-api.git
synced 2026-02-04 18:31:13 -05:00
replace m2crypto with oscrypto
This commit is contained in:
@@ -10,8 +10,7 @@ from flask import Blueprint, current_app, json, jsonify, request
|
|||||||
from sqlalchemy.orm.exc import NoResultFound
|
from sqlalchemy.orm.exc import NoResultFound
|
||||||
|
|
||||||
from app import notify_celery, statsd_client, redis_store
|
from app import notify_celery, statsd_client, redis_store
|
||||||
# from app.celery.validate_sns import valid_sns_message
|
from app.celery.validate_sns import validate_sns_message
|
||||||
import validatesns
|
|
||||||
from app.config import QueueNames
|
from app.config import QueueNames
|
||||||
from app.dao import notifications_dao
|
from app.dao import notifications_dao
|
||||||
from app.errors import InvalidRequest, register_errors
|
from app.errors import InvalidRequest, register_errors
|
||||||
@@ -44,13 +43,6 @@ def verify_message_type(message_type: str):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
raise InvalidMessageTypeException(f'{message_type} is not a valid message type.')
|
raise InvalidMessageTypeException(f'{message_type} is not a valid message type.')
|
||||||
|
|
||||||
def get_certificate(url):
|
|
||||||
res = redis_store.get(url)
|
|
||||||
if res is not None:
|
|
||||||
return res
|
|
||||||
res = requests.get(url).content
|
|
||||||
redis_store.set(url, res, ex=60 * 60) # 60 minutes
|
|
||||||
return res
|
|
||||||
|
|
||||||
# 400 counts as a permanent failure so SNS will not retry.
|
# 400 counts as a permanent failure so SNS will not retry.
|
||||||
# 500 counts as a failed delivery attempt so SNS will retry.
|
# 500 counts as a failed delivery attempt so SNS will retry.
|
||||||
@@ -73,27 +65,12 @@ def sns_callback_handler():
|
|||||||
current_app.logger.exception(f"Response headers: {request.headers}\nResponse data: {request.data}")
|
current_app.logger.exception(f"Response headers: {request.headers}\nResponse data: {request.data}")
|
||||||
raise InvalidRequest("SES-SNS callback failed: invalid JSON given", 400)
|
raise InvalidRequest("SES-SNS callback failed: invalid JSON given", 400)
|
||||||
|
|
||||||
current_app.logger.info(f"Message type: {message_type}\nResponse data: {message}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# AWS sends SigningCertURL if sending to a webhook, but SigningCertUrl if sending to a Lambda function
|
validate_sns_message(message)
|
||||||
message["SigningCertURL"] = message["SigningCertURL"] if "SigningCertURL" in message else message["SigningCertUrl"]
|
|
||||||
# Some SNS messages now contain "Subject": null, which is not handled by the validatesns library
|
|
||||||
if "Subject" in message and message["Subject"] == None:
|
|
||||||
message.pop("Subject")
|
|
||||||
validatesns.validate(message, get_certificate=get_certificate, max_age=DEFAULT_MAX_AGE)
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(f"SES-SNS callback failed: validation failed! Response headers: {request.headers}\nResponse data: {request.data}\nError: Signature validation failed with error {err} and traceback {traceback.format_exc()}")
|
current_app.logger.error(f"SES-SNS callback failed: validation failed! Response headers: {request.headers}\nResponse data: {request.data}\nError: Signature validation failed with error {err}")
|
||||||
raise InvalidRequest("SES-SNS callback failed: validation failed", 400)
|
raise InvalidRequest("SES-SNS callback failed: validation failed", 400)
|
||||||
|
|
||||||
# try:
|
|
||||||
# if valid_sns_message(message) == False:
|
|
||||||
# current_app.logger.error(f"SES-SNS callback failed: validation failed! Response headers: {request.headers}\nResponse data: {request.data}\nError: Signature validation failed.")
|
|
||||||
# raise InvalidRequest("SES-SNS callback failed: validation failed", 400)
|
|
||||||
# except Exception as e:
|
|
||||||
# current_app.logger.exception(f"SES-SNS callback failed: validation failed! Response headers: {request.headers}\nResponse data: {request.data}\nError: {e}")
|
|
||||||
# raise InvalidRequest("SES-SNS callback failed: validation failed", 400)
|
|
||||||
|
|
||||||
if message.get('Type') == 'SubscriptionConfirmation':
|
if message.get('Type') == 'SubscriptionConfirmation':
|
||||||
url = message.get('SubscribeUrl') if 'SubscribeUrl' in message else message.get('SubscribeURL')
|
url = message.get('SubscribeUrl') if 'SubscribeUrl' in message else message.get('SubscribeURL')
|
||||||
response = requests.get(url)
|
response = requests.get(url)
|
||||||
|
|||||||
@@ -3,11 +3,14 @@ import re
|
|||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from M2Crypto import X509
|
import oscrypto.asymmetric
|
||||||
|
import oscrypto.errors
|
||||||
|
|
||||||
from app import redis_store
|
from app import redis_store
|
||||||
from app.config import Config
|
from app.config import Config
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
USE_CACHE = True
|
USE_CACHE = True
|
||||||
VALIDATE_ARN = True
|
VALIDATE_ARN = True
|
||||||
VALID_SNS_TOPICS = Config.VALID_SNS_TOPICS
|
VALID_SNS_TOPICS = Config.VALID_SNS_TOPICS
|
||||||
@@ -18,6 +21,11 @@ _cert_url_re = re.compile(
|
|||||||
r'sns\.([a-z]{1,3}-[a-z]+-[0-9]{1,2})\.amazonaws\.com',
|
r'sns\.([a-z]{1,3}-[a-z]+-[0-9]{1,2})\.amazonaws\.com',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class ValidationError(Exception):
|
||||||
|
"""
|
||||||
|
ValidationError. Raised when a message fails integrity checks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_certificate(url):
|
def get_certificate(url):
|
||||||
if USE_CACHE:
|
if USE_CACHE:
|
||||||
@@ -31,80 +39,75 @@ def get_certificate(url):
|
|||||||
return requests.get(url).text
|
return requests.get(url).text
|
||||||
|
|
||||||
|
|
||||||
def valid_sns_message(sns_payload):
|
def validate_arn(sns_payload):
|
||||||
"""
|
|
||||||
Adapted from the solution posted at
|
|
||||||
https://github.com/boto/boto3/issues/2508#issuecomment-992931814
|
|
||||||
"""
|
|
||||||
if not isinstance(sns_payload, dict):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Amazon SNS currently supports signature version 1.
|
|
||||||
if sns_payload.get('SignatureVersion') != '1':
|
|
||||||
return False
|
|
||||||
|
|
||||||
if VALIDATE_ARN:
|
if VALIDATE_ARN:
|
||||||
arn = sns_payload.get('TopicArn')
|
arn = sns_payload.get('TopicArn')
|
||||||
topic_name = arn.split(':')[5]
|
topic_name = arn.split(':')[5]
|
||||||
if topic_name not in VALID_SNS_TOPICS:
|
if topic_name not in VALID_SNS_TOPICS:
|
||||||
return False
|
raise ValidationError("Invalid Topic Name")
|
||||||
|
|
||||||
|
|
||||||
|
def get_string_to_sign(sns_payload):
|
||||||
payload_type = sns_payload.get('Type')
|
payload_type = sns_payload.get('Type')
|
||||||
if payload_type in ['SubscriptionConfirmation', 'UnsubscribeConfirmation']:
|
if payload_type in ['SubscriptionConfirmation', 'UnsubscribeConfirmation']:
|
||||||
fields = ['Message', 'MessageId', 'SubscribeURL', 'Timestamp', 'Token', 'TopicArn', 'Type']
|
fields = ['Message', 'MessageId', 'SubscribeURL', 'Timestamp', 'Token', 'TopicArn', 'Type']
|
||||||
elif payload_type == 'Notification':
|
elif payload_type == 'Notification':
|
||||||
fields = ['Message', 'MessageId', 'Subject', 'Timestamp', 'TopicArn', 'Type']
|
fields = ['Message', 'MessageId', 'Subject', 'Timestamp', 'TopicArn', 'Type']
|
||||||
else:
|
else:
|
||||||
return False
|
raise ValidationError("Unexpected Message Type")
|
||||||
|
|
||||||
# Build the string to be signed.
|
|
||||||
string_to_sign = ''
|
string_to_sign = ''
|
||||||
for field in fields:
|
for field in fields:
|
||||||
field_value = sns_payload.get(field)
|
field_value = sns_payload.get(field)
|
||||||
if not isinstance(field_value, str):
|
if not isinstance(field_value, str):
|
||||||
return False
|
if field == 'Subject' and field_value == None:
|
||||||
|
continue
|
||||||
|
raise ValidationError(f"In {field}, found non-string value: {field_value}")
|
||||||
string_to_sign += field + '\n' + field_value + '\n'
|
string_to_sign += field + '\n' + field_value + '\n'
|
||||||
|
if isinstance(string_to_sign, six.text_type):
|
||||||
|
string_to_sign = string_to_sign.encode()
|
||||||
|
return string_to_sign
|
||||||
|
|
||||||
# Get the signature
|
|
||||||
try:
|
def validate_sns_message(sns_payload):
|
||||||
decoded_signature = base64.b64decode(sns_payload.get('Signature'))
|
"""
|
||||||
except (TypeError, ValueError):
|
Adapted from the solution posted at
|
||||||
return False
|
https://github.com/boto/boto3/issues/2508#issuecomment-992931814
|
||||||
|
"""
|
||||||
|
if not isinstance(sns_payload, dict):
|
||||||
|
raise ValidationError("Unexpected message type {!r}".format(type(sns_payload).__name__))
|
||||||
|
|
||||||
|
# Amazon SNS currently supports signature version 1.
|
||||||
|
if sns_payload.get('SignatureVersion') != '1':
|
||||||
|
raise ValidationError("Wrong Signature Version (expected 1)")
|
||||||
|
|
||||||
|
validate_arn(sns_payload)
|
||||||
|
|
||||||
|
string_to_sign = get_string_to_sign(sns_payload)
|
||||||
|
|
||||||
# Key signing cert url via Lambda and via webhook are slightly different
|
# Key signing cert url via Lambda and via webhook are slightly different
|
||||||
signing_cert_url = sns_payload.get('SigningCertUrl') if 'SigningCertUrl' in sns_payload else sns_payload.get('SigningCertURL')
|
signing_cert_url = sns_payload.get('SigningCertUrl') if 'SigningCertUrl' in sns_payload else sns_payload.get('SigningCertURL')
|
||||||
if not isinstance(signing_cert_url, str):
|
if not isinstance(signing_cert_url, str):
|
||||||
return False
|
raise ValidationError("Signing cert url must be a string")
|
||||||
cert_scheme, cert_netloc, *_ = urlparse(signing_cert_url)
|
cert_scheme, cert_netloc, *_ = urlparse(signing_cert_url)
|
||||||
if cert_scheme != 'https' or not re.match(_cert_url_re, cert_netloc):
|
if cert_scheme != 'https' or not re.match(_cert_url_re, cert_netloc):
|
||||||
# The cert doesn't seem to be from AWS
|
raise ValidationError("Cert does not appear to be from AWS")
|
||||||
return False
|
|
||||||
certificate = _signing_cert_cache.get(signing_cert_url)
|
certificate = _signing_cert_cache.get(signing_cert_url)
|
||||||
if certificate is None:
|
if certificate is None:
|
||||||
certificate = X509.load_cert_string(get_certificate(signing_cert_url))
|
certificate = get_certificate(signing_cert_url)
|
||||||
_signing_cert_cache[signing_cert_url] = certificate
|
if isinstance(certificate, six.text_type):
|
||||||
|
certificate = certificate.encode()
|
||||||
|
|
||||||
if certificate.get_subject().as_text() != 'CN=sns.amazonaws.com':
|
signature = base64.b64decode(sns_payload["Signature"])
|
||||||
return False
|
|
||||||
|
|
||||||
# Extract the public key.
|
try:
|
||||||
public_key = certificate.get_pubkey()
|
oscrypto.asymmetric.rsa_pkcs1v15_verify(
|
||||||
|
oscrypto.asymmetric.load_certificate(certificate),
|
||||||
# Amazon SNS uses SHA1withRSA.
|
signature,
|
||||||
# http://sns-public-resources.s3.amazonaws.com/SNS_Message_Signing_Release_Note_Jan_25_2011.pdf
|
string_to_sign,
|
||||||
public_key.reset_context(md='sha1')
|
"sha1"
|
||||||
public_key.verify_init()
|
)
|
||||||
|
return True
|
||||||
# Sign the string.
|
except oscrypto.errors.SignatureError:
|
||||||
public_key.verify_update(string_to_sign.encode())
|
raise ValidationError("Invalid signature")
|
||||||
|
|
||||||
# Verify the signature matches.
|
|
||||||
verification_result = public_key.verify_final(decoded_signature)
|
|
||||||
|
|
||||||
# M2Crypto uses EVP_VerifyFinal() from openssl as the underlying
|
|
||||||
# verification function. 1 indicates success, anything else is either
|
|
||||||
# a failure or an error.
|
|
||||||
if verification_result != 1:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
Reference in New Issue
Block a user