From 58a8b51f59736ad4c662ebdfbe3594dd884b397a Mon Sep 17 00:00:00 2001 From: Kenneth Kehl <@kkehl@flexion.us> Date: Thu, 26 Jun 2025 10:35:46 -0700 Subject: [PATCH] more input checking --- app/job/rest.py | 54 ++------------------------------- app/notifications/rest.py | 2 ++ app/organization/invite_rest.py | 8 +++++ app/organization/rest.py | 9 ++++++ app/provider_details/rest.py | 4 +++ app/service/rest.py | 54 ++++++++++++++++++++++++++++++++- app/service_invite/rest.py | 8 ++++- app/template/rest.py | 9 +++++- app/template_folder/rest.py | 7 +++++ app/template_statistics/rest.py | 4 ++- app/upload/rest.py | 3 +- app/user/rest.py | 23 +++++++++++++- app/utils.py | 52 ++++++++++++++++++++++++++++++- app/webauthn/rest.py | 4 +++ tests/app/job/test_rest.py | 27 ----------------- tests/app/test_utils.py | 32 +++++++++++++++++++ 16 files changed, 214 insertions(+), 86 deletions(-) diff --git a/app/job/rest.py b/app/job/rest.py index 831c3f2f4..6522ce93e 100644 --- a/app/job/rest.py +++ b/app/job/rest.py @@ -1,8 +1,7 @@ -import re from zoneinfo import ZoneInfo import dateutil -from flask import Blueprint, abort, current_app, jsonify, request +from flask import Blueprint, current_app, jsonify, request from app import db from app.aws.s3 import ( @@ -37,7 +36,7 @@ from app.schemas import ( notification_with_template_schema, notifications_filter_schema, ) -from app.utils import midnight_n_days_ago, pagination_links +from app.utils import check_suspicious_id, midnight_n_days_ago, pagination_links job_blueprint = Blueprint("job", __name__, url_prefix="/service//job") @@ -45,55 +44,6 @@ job_blueprint = Blueprint("job", __name__, url_prefix="/service/", methods=["GET"]) def get_job_by_service_and_job_id(service_id, job_id): check_suspicious_id(service_id, job_id) diff --git a/app/notifications/rest.py b/app/notifications/rest.py index 8e1f0c58a..55553b4bf 100644 --- a/app/notifications/rest.py +++ b/app/notifications/rest.py @@ -23,6 +23,7 @@ from app.schemas import ( ) from app.service.utils import service_allowed_to_send_to from app.utils import ( + check_suspicious_id, get_public_notify_type_text, get_template_instance, pagination_links, @@ -36,6 +37,7 @@ register_errors(notifications) @notifications.route("/notifications/", methods=["GET"]) def get_notification_by_id(notification_id): + check_suspicious_id(notification_id) notification = notifications_dao.get_notification_with_personalisation( str(authenticated_service.id), notification_id, key_type=None ) diff --git a/app/organization/invite_rest.py b/app/organization/invite_rest.py index caa803485..2e6a9ba2c 100644 --- a/app/organization/invite_rest.py +++ b/app/organization/invite_rest.py @@ -27,6 +27,7 @@ from app.organization.organization_schema import ( post_update_invited_org_user_status_schema, ) from app.schema_validation import validate +from app.utils import check_suspicious_id from notifications_utils.url_safe_token import check_token, generate_token organization_invite_blueprint = Blueprint("organization_invite", __name__) @@ -38,6 +39,7 @@ register_errors(organization_invite_blueprint) "/organization//invite", methods=["POST"] ) def invite_user_to_org(organization_id): + check_suspicious_id(organization_id) data = request.get_json() validate(data, post_create_invited_org_user_status_schema) @@ -98,6 +100,8 @@ def invite_user_to_org(organization_id): "/organization//invite", methods=["GET"] ) def get_invited_org_users_by_organization(organization_id): + + check_suspicious_id(organization_id) invited_org_users = get_invited_org_users_for_organization(organization_id) return jsonify(data=[x.serialize() for x in invited_org_users]), 200 @@ -106,6 +110,7 @@ def get_invited_org_users_by_organization(organization_id): "/organization//invite/", methods=["GET"] ) def get_invited_org_user_by_organization(organization_id, invited_org_user_id): + check_suspicious_id(organization_id, invited_org_user_id) invited_org_user = dao_get_invited_org_user(organization_id, invited_org_user_id) return jsonify(data=invited_org_user.serialize()), 200 @@ -115,6 +120,7 @@ def get_invited_org_user_by_organization(organization_id, invited_org_user_id): methods=["POST"], ) def update_org_invite_status(organization_id, invited_org_user_id): + check_suspicious_id(organization_id, invited_org_user_id) fetched = dao_get_invited_org_user( organization_id=organization_id, invited_org_user_id=invited_org_user_id ) @@ -129,6 +135,7 @@ def update_org_invite_status(organization_id, invited_org_user_id): def invited_org_user_url(invited_org_user_id, invite_link_host=None): + token = generate_token( str(invited_org_user_id), current_app.config["SECRET_KEY"], @@ -145,6 +152,7 @@ def invited_org_user_url(invited_org_user_id, invite_link_host=None): "/invite/organization/", methods=["GET"] ) def get_invited_org_user(invited_org_user_id): + check_suspicious_id(invited_org_user_id) invited_user = get_invited_org_user_by_id(invited_org_user_id) return jsonify(data=invited_user.serialize()), 200 diff --git a/app/organization/rest.py b/app/organization/rest.py index f3d887511..df20d3cf3 100644 --- a/app/organization/rest.py +++ b/app/organization/rest.py @@ -36,6 +36,7 @@ from app.organization.organization_schema import ( post_update_organization_schema, ) from app.schema_validation import validate +from app.utils import check_suspicious_id organization_blueprint = Blueprint("organization", __name__) register_errors(organization_blueprint) @@ -64,6 +65,7 @@ def get_organizations(): @organization_blueprint.route("/", methods=["GET"]) def get_organization_by_id(organization_id): + check_suspicious_id(organization_id) organization = dao_get_organization_by_id(organization_id) return jsonify(organization.serialize()) @@ -97,6 +99,7 @@ def create_organization(): @organization_blueprint.route("/", methods=["POST"]) def update_organization(organization_id): + check_suspicious_id(organization_id) data = request.get_json() validate(data, post_update_organization_schema) @@ -115,6 +118,7 @@ def update_organization(organization_id): @organization_blueprint.route("//service", methods=["POST"]) def link_service_to_organization(organization_id): + check_suspicious_id(organization_id) data = request.get_json() validate(data, post_link_service_to_organization_schema) service = dao_fetch_service_by_id(data["service_id"]) @@ -129,6 +133,7 @@ def link_service_to_organization(organization_id): @organization_blueprint.route("//services", methods=["GET"]) def get_organization_services(organization_id): + check_suspicious_id(organization_id) services = dao_get_organization_services(organization_id) sorted_services = sorted(services, key=lambda s: (-s.active, s.name)) return jsonify([s.serialize_for_org_dashboard() for s in sorted_services]) @@ -138,6 +143,7 @@ def get_organization_services(organization_id): "//services-with-usage", methods=["GET"] ) def get_organization_services_usage(organization_id): + check_suspicious_id(organization_id) try: year = int(request.args.get("year", "none")) except ValueError: @@ -154,6 +160,7 @@ def get_organization_services_usage(organization_id): "//users/", methods=["POST"] ) def add_user_to_organization(organization_id, user_id): + check_suspicious_id(organization_id, user_id) new_org_user = dao_add_user_to_organization(organization_id, user_id) return jsonify(data=new_org_user.serialize()) @@ -162,6 +169,7 @@ def add_user_to_organization(organization_id, user_id): "//users/", methods=["DELETE"] ) def remove_user_from_organization(organization_id, user_id): + check_suspicious_id(organization_id, user_id) organization = dao_get_organization_by_id(organization_id) user = get_user_by_id(user_id=user_id) @@ -176,6 +184,7 @@ def remove_user_from_organization(organization_id, user_id): @organization_blueprint.route("//users", methods=["GET"]) def get_organization_users(organization_id): + check_suspicious_id(organization_id) org_users = dao_get_users_for_organization(organization_id) return jsonify(data=[x.serialize() for x in org_users]) diff --git a/app/provider_details/rest.py b/app/provider_details/rest.py index 3a7e62332..db8d2b8db 100644 --- a/app/provider_details/rest.py +++ b/app/provider_details/rest.py @@ -9,6 +9,7 @@ from app.dao.provider_details_dao import ( from app.dao.users_dao import get_user_by_id from app.errors import InvalidRequest, register_errors from app.schemas import provider_details_history_schema, provider_details_schema +from app.utils import check_suspicious_id provider_details = Blueprint("provider_details", __name__) register_errors(provider_details) @@ -38,12 +39,14 @@ def get_providers(): @provider_details.route("/", methods=["GET"]) def get_provider_by_id(provider_details_id): + check_suspicious_id(provider_details_id) data = provider_details_schema.dump(get_provider_details_by_id(provider_details_id)) return jsonify(provider_details=data) @provider_details.route("//versions", methods=["GET"]) def get_provider_versions(provider_details_id): + check_suspicious_id(provider_details_id) versions = dao_get_provider_versions(provider_details_id) data = provider_details_history_schema.dump(versions, many=True) return jsonify(data=data) @@ -51,6 +54,7 @@ def get_provider_versions(provider_details_id): @provider_details.route("/", methods=["POST"]) def update_provider_details(provider_details_id): + check_suspicious_id(provider_details_id) valid_keys = {"priority", "created_by", "active"} req_json = request.get_json() diff --git a/app/service/rest.py b/app/service/rest.py index 49052fead..b204a940e 100644 --- a/app/service/rest.py +++ b/app/service/rest.py @@ -116,7 +116,7 @@ from app.service.service_senders_schema import ( ) from app.service.utils import get_guest_list_objects from app.user.users_schema import post_set_permissions_schema -from app.utils import get_prev_next_pagination_links, utc_now +from app.utils import check_suspicious_id, get_prev_next_pagination_links, utc_now service_blueprint = Blueprint("service", __name__) @@ -202,6 +202,7 @@ def get_live_services_data(): @service_blueprint.route("/", methods=["GET"]) def get_service_by_id(service_id): + check_suspicious_id(service_id) if request.args.get("detailed") == "True": data = get_detailed_service( service_id, today_only=request.args.get("today_only") == "True" @@ -217,6 +218,7 @@ def get_service_by_id(service_id): @service_blueprint.route("//statistics") def get_service_notification_statistics(service_id): + check_suspicious_id(service_id) return jsonify( data=get_service_statistics( service_id, @@ -228,12 +230,14 @@ def get_service_notification_statistics(service_id): @service_blueprint.route("//statistics//") def get_service_notification_statistics_by_day(service_id, start, days): + check_suspicious_id(service_id) return jsonify( data=get_service_statistics_for_specific_days(service_id, start, int(days)) ) def get_service_statistics_for_specific_days(service_id, start, days=1): + check_suspicious_id(service_id) # Calculate start and end date range end_date = datetime.strptime(start, "%Y-%m-%d") start_date = end_date - timedelta(days=days - 1) @@ -264,6 +268,7 @@ def get_service_statistics_for_specific_days(service_id, start, days=1): def get_service_notification_statistics_by_day_by_user( service_id, user_id, start, days ): + check_suspicious_id(service_id, user_id) return jsonify( data=get_service_statistics_for_specific_days_by_user( service_id, user_id, start, int(days) @@ -323,6 +328,7 @@ def create_service(): @service_blueprint.route("/", methods=["POST"]) def update_service(service_id): + check_suspicious_id(service_id) req_json = request.get_json() fetched_service = dao_fetch_service_by_id(service_id) service_going_live = fetched_service.restricted and not req_json.get( @@ -364,6 +370,8 @@ def update_service(service_id): @service_blueprint.route("//api-key", methods=["POST"]) def create_api_key(service_id=None): + if service_id: + check_suspicious_id(service_id) fetched_service = dao_fetch_service_by_id(service_id=service_id) valid_api_key = api_key_schema.load(request.get_json(), session=db.session) valid_api_key.service = fetched_service @@ -376,6 +384,7 @@ def create_api_key(service_id=None): "//api-key/revoke/", methods=["POST"] ) def revoke_api_key(service_id, api_key_id): + check_suspicious_id(service_id, api_key_id) expire_api_key(service_id=service_id, api_key_id=api_key_id) return jsonify(), 202 @@ -383,6 +392,10 @@ def revoke_api_key(service_id, api_key_id): @service_blueprint.route("//api-keys", methods=["GET"]) @service_blueprint.route("//api-keys/", methods=["GET"]) def get_api_keys(service_id, key_id=None): + if key_id: + check_suspicious_id(service_id, key_id) + else: + check_suspicious_id(service_id) dao_fetch_service_by_id(service_id=service_id) try: @@ -399,12 +412,14 @@ def get_api_keys(service_id, key_id=None): @service_blueprint.route("//users", methods=["GET"]) def get_users_for_service(service_id): + check_suspicious_id(service_id) fetched = dao_fetch_service_by_id(service_id) return jsonify(data=[x.serialize() for x in fetched.users]) @service_blueprint.route("//users/", methods=["POST"]) def add_user_to_service(service_id, user_id): + check_suspicious_id(service_id, user_id) service = dao_fetch_service_by_id(service_id) user = get_user_by_id(user_id=user_id) if user in service.users: @@ -428,6 +443,7 @@ def add_user_to_service(service_id, user_id): @service_blueprint.route("//users/", methods=["DELETE"]) def remove_user_from_service(service_id, user_id): + check_suspicious_id(service_id, user_id) service = dao_fetch_service_by_id(service_id) user = get_user_by_id(user_id=user_id) if user not in service.users: @@ -447,6 +463,7 @@ def remove_user_from_service(service_id, user_id): # tables. This is so product owner can pass stories as done @service_blueprint.route("//history", methods=["GET"]) def get_service_history(service_id): + check_suspicious_id(service_id) from app.models import ApiKey, Service, TemplateHistory from app.schemas import ( api_key_history_schema, @@ -496,6 +513,7 @@ def get_service_history(service_id): @service_blueprint.route("//notifications", methods=["GET", "POST"]) def get_all_notifications_for_service(service_id): + check_suspicious_id(service_id) current_app.logger.debug("enter get_all_notifications_for_service") if request.method == "GET": data = notifications_filter_schema.load(request.args) @@ -625,6 +643,7 @@ def get_all_notifications_for_service(service_id): "//notifications/", methods=["GET"] ) def get_notification_for_service(service_id, notification_id): + check_suspicious_id(service_id, notification_id) notification = notifications_dao.get_notification_with_personalisation( service_id, notification_id, @@ -640,6 +659,7 @@ def get_notification_for_service(service_id, notification_id): @service_blueprint.route("//notifications/monthly", methods=["GET"]) def get_monthly_notification_stats(service_id): + check_suspicious_id(service_id) # check service_id validity dao_fetch_service_by_id(service_id) @@ -672,6 +692,7 @@ def get_monthly_notification_stats(service_id): "//notifications//monthly", methods=["GET"] ) def get_monthly_notification_stats_by_user(service_id, user_id): + check_suspicious_id(service_id, user_id) # check service_id validity dao_fetch_service_by_id(service_id) # user = get_user_by_id(user_id=user_id) @@ -705,6 +726,7 @@ def get_monthly_notification_stats_by_user(service_id, user_id): "//notifications//month", methods=["GET"] ) def get_single_month_notification_stats_by_user(service_id, user_id): + check_suspicious_id(service_id, user_id) # check service_id validity dao_fetch_service_by_id(service_id) @@ -734,6 +756,7 @@ def get_single_month_notification_stats_by_user(service_id, user_id): @service_blueprint.route("//notifications/month", methods=["GET"]) def get_single_month_notification_stats_for_service(service_id): + check_suspicious_id(service_id) # check service_id validity dao_fetch_service_by_id(service_id) @@ -757,6 +780,7 @@ def get_single_month_notification_stats_for_service(service_id): def get_detailed_service(service_id, today_only=False): + check_suspicious_id(service_id) service = dao_fetch_service_by_id(service_id) service.statistics = get_service_statistics(service_id, today_only) @@ -764,6 +788,7 @@ def get_detailed_service(service_id, today_only=False): def get_service_statistics(service_id, today_only, limit_days=7): + check_suspicious_id(service_id) # today_only flag is used by the send page to work out if the service will exceed their daily usage by sending a job if today_only: stats = dao_fetch_todays_stats_for_service(service_id) @@ -810,6 +835,7 @@ def get_detailed_services( @service_blueprint.route("//guest-list", methods=["GET"]) def get_guest_list(service_id): + check_suspicious_id(service_id) from app.enums import RecipientType service = dao_fetch_service_by_id(service_id) @@ -834,6 +860,7 @@ def get_guest_list(service_id): @service_blueprint.route("//guest-list", methods=["PUT"]) def update_guest_list(service_id): + check_suspicious_id(service_id) # doesn't commit so if there are any errors, we preserve old values in db dao_remove_service_guest_list(service_id) try: @@ -850,6 +877,7 @@ def update_guest_list(service_id): @service_blueprint.route("//archive", methods=["POST"]) def archive_service(service_id): + check_suspicious_id(service_id) """ When a service is archived the service is made inactive, templates are archived and api keys are revoked. There is no coming back from this operation. @@ -866,6 +894,7 @@ def archive_service(service_id): @service_blueprint.route("//suspend", methods=["POST"]) def suspend_service(service_id): + check_suspicious_id(service_id) """ Suspending a service will mark the service as inactive and revoke API keys. :param service_id: @@ -881,6 +910,7 @@ def suspend_service(service_id): @service_blueprint.route("//resume", methods=["POST"]) def resume_service(service_id): + check_suspicious_id(service_id) """ Resuming a service that has been suspended will mark the service as active. The service will need to re-create API keys @@ -899,6 +929,7 @@ def resume_service(service_id): "//notifications/templates_usage/monthly", methods=["GET"] ) def get_monthly_template_usage(service_id): + check_suspicious_id(service_id) try: start_date, end_date = get_calendar_year(int(request.args.get("year", "NaN"))) data = fetch_monthly_template_usage_for_service( @@ -924,12 +955,14 @@ def get_monthly_template_usage(service_id): @service_blueprint.route("//send-notification", methods=["POST"]) def create_one_off_notification(service_id): + check_suspicious_id(service_id) resp = send_one_off_notification(service_id, request.get_json()) return jsonify(resp), 201 @service_blueprint.route("//email-reply-to", methods=["GET"]) def get_email_reply_to_addresses(service_id): + check_suspicious_id(service_id) result = dao_get_reply_to_by_service_id(service_id) return jsonify([i.serialize() for i in result]), 200 @@ -938,12 +971,14 @@ def get_email_reply_to_addresses(service_id): "//email-reply-to/", methods=["GET"] ) def get_email_reply_to_address(service_id, reply_to_id): + check_suspicious_id(service_id, reply_to_id) result = dao_get_reply_to_by_id(service_id=service_id, reply_to_id=reply_to_id) return jsonify(result.serialize()), 200 @service_blueprint.route("//email-reply-to/verify", methods=["POST"]) def verify_reply_to_email_address(service_id): + check_suspicious_id(service_id) email_address = email_data_request_schema.load(request.get_json()) check_if_reply_to_address_already_in_use(service_id, email_address["email"]) @@ -970,6 +1005,7 @@ def verify_reply_to_email_address(service_id): @service_blueprint.route("//email-reply-to", methods=["POST"]) def add_service_reply_to_email_address(service_id): + check_suspicious_id(service_id) # validate the service exists, throws ResultNotFound exception. dao_fetch_service_by_id(service_id) form = validate(request.get_json(), add_service_email_reply_to_request) @@ -986,6 +1022,7 @@ def add_service_reply_to_email_address(service_id): "//email-reply-to/", methods=["POST"] ) def update_service_reply_to_email_address(service_id, reply_to_email_id): + check_suspicious_id(service_id, reply_to_email_id) # validate the service exists, throws ResultNotFound exception. dao_fetch_service_by_id(service_id) form = validate(request.get_json(), add_service_email_reply_to_request) @@ -1003,6 +1040,7 @@ def update_service_reply_to_email_address(service_id, reply_to_email_id): methods=["POST"], ) def delete_service_reply_to_email_address(service_id, reply_to_email_id): + check_suspicious_id(service_id, reply_to_email_id) archived_reply_to = archive_reply_to_email_address(service_id, reply_to_email_id) return jsonify(data=archived_reply_to.serialize()), 200 @@ -1010,6 +1048,7 @@ def delete_service_reply_to_email_address(service_id, reply_to_email_id): @service_blueprint.route("//sms-sender", methods=["POST"]) def add_service_sms_sender(service_id): + check_suspicious_id(service_id) dao_fetch_service_by_id(service_id) form = validate(request.get_json(), add_service_sms_sender_request) inbound_number_id = form.get("inbound_number_id", None) @@ -1046,6 +1085,7 @@ def add_service_sms_sender(service_id): "//sms-sender/", methods=["POST"] ) def update_service_sms_sender(service_id, sms_sender_id): + check_suspicious_id(service_id, sms_sender_id) form = validate(request.get_json(), add_service_sms_sender_request) sms_sender_to_update = dao_get_service_sms_senders_by_id( @@ -1073,6 +1113,7 @@ def update_service_sms_sender(service_id, sms_sender_id): "//sms-sender//archive", methods=["POST"] ) def delete_service_sms_sender(service_id, sms_sender_id): + check_suspicious_id(service_id, sms_sender_id) sms_sender = archive_sms_sender(service_id, sms_sender_id) return jsonify(data=sms_sender.serialize()), 200 @@ -1082,6 +1123,8 @@ def delete_service_sms_sender(service_id, sms_sender_id): "//sms-sender/", methods=["GET"] ) def get_service_sms_sender_by_id(service_id, sms_sender_id): + check_suspicious_id(service_id, sms_sender_id) + sms_sender = dao_get_service_sms_senders_by_id( service_id=service_id, service_sms_sender_id=sms_sender_id ) @@ -1090,18 +1133,22 @@ def get_service_sms_sender_by_id(service_id, sms_sender_id): @service_blueprint.route("//sms-sender", methods=["GET"]) def get_service_sms_senders_for_service(service_id): + check_suspicious_id(service_id) + sms_senders = dao_get_sms_senders_by_service_id(service_id=service_id) return jsonify([sms_sender.serialize() for sms_sender in sms_senders]), 200 @service_blueprint.route("//organization", methods=["GET"]) def get_organization_for_service(service_id): + check_suspicious_id(service_id) organization = dao_get_organization_by_service_id(service_id=service_id) return jsonify(organization.serialize() if organization else {}), 200 @service_blueprint.route("//data-retention", methods=["GET"]) def get_data_retention_for_service(service_id): + check_suspicious_id(service_id) data_retention_list = fetch_service_data_retention(service_id) return ( jsonify([data_retention.serialize() for data_retention in data_retention_list]), @@ -1114,6 +1161,7 @@ def get_data_retention_for_service(service_id): methods=["GET"], ) def get_data_retention_for_service_notification_type(service_id, notification_type): + check_suspicious_id(service_id) data_retention = fetch_service_data_retention_by_notification_type( service_id, notification_type ) @@ -1124,12 +1172,14 @@ def get_data_retention_for_service_notification_type(service_id, notification_ty "//data-retention/", methods=["GET"] ) def get_data_retention_for_service_by_id(service_id, data_retention_id): + check_suspicious_id(service_id, data_retention_id) data_retention = fetch_service_data_retention_by_id(service_id, data_retention_id) return jsonify(data_retention.serialize() if data_retention else {}), 200 @service_blueprint.route("//data-retention", methods=["POST"]) def create_service_data_retention(service_id): + check_suspicious_id(service_id) form = validate(request.get_json(), add_service_data_retention_request) try: new_data_retention = insert_service_data_retention( @@ -1152,6 +1202,7 @@ def create_service_data_retention(service_id): "//data-retention/", methods=["POST"] ) def modify_service_data_retention(service_id, data_retention_id): + check_suspicious_id(service_id, data_retention_id) form = validate(request.get_json(), update_service_data_retention_request) update_count = update_service_data_retention( @@ -1252,5 +1303,6 @@ def check_if_reply_to_address_already_in_use(service_id, email_address): @service_blueprint.route("//notification-count", methods=["GET"]) def get_notification_count_for_service_id(service_id): + check_suspicious_id(service_id) count = dao_get_notification_count_for_service(service_id=service_id) return jsonify(count=count), 200 diff --git a/app/service_invite/rest.py b/app/service_invite/rest.py index 9dc862ae1..0e36c00f1 100644 --- a/app/service_invite/rest.py +++ b/app/service_invite/rest.py @@ -25,7 +25,7 @@ from app.notifications.process_notifications import ( send_notification_to_queue, ) from app.schemas import InvitedUserSchema -from app.utils import utc_now +from app.utils import check_suspicious_id, utc_now from notifications_utils.url_safe_token import check_token, generate_token service_invite = Blueprint("service_invite", __name__) @@ -121,6 +121,7 @@ def create_invited_user(service_id): @service_invite.route("/service//invite/expired", methods=["GET"]) def get_expired_invited_users_by_service(service_id): + check_suspicious_id(service_id) expired_invited_users = get_expired_invited_users_for_service(service_id) return ( jsonify( @@ -134,6 +135,7 @@ def get_expired_invited_users_by_service(service_id): @service_invite.route("/service//invite", methods=["GET"]) def get_invited_users_by_service(service_id): + check_suspicious_id(service_id) invited_users = get_invited_users_for_service(service_id) return ( jsonify( @@ -145,6 +147,7 @@ def get_invited_users_by_service(service_id): @service_invite.route("/service//invite/", methods=["GET"]) def get_invited_user_by_service(service_id, invited_user_id): + check_suspicious_id(service_id, invited_user_id) invited_user = get_invited_user_by_service_and_id(service_id, invited_user_id) return jsonify(data=InvitedUserSchema(session=db.session).dump(invited_user)), 200 @@ -153,6 +156,7 @@ def get_invited_user_by_service(service_id, invited_user_id): "/service//invite/", methods=["POST"] ) def update_invited_user(service_id, invited_user_id): + check_suspicious_id(service_id, invited_user_id) fetched = get_invited_user_by_service_and_id( service_id=service_id, invited_user_id=invited_user_id ) @@ -168,6 +172,7 @@ def update_invited_user(service_id, invited_user_id): "/service//invite//resend", methods=["POST"] ) def resend_service_invite(service_id, invited_user_id): + check_suspicious_id(service_id, invited_user_id) """Resend an expired invite. This resets the invited user's created date and status to make it a new invite, and @@ -228,6 +233,7 @@ def invited_user_url(invited_user_id, invite_link_host=None): @service_invite.route("/invite/service/", methods=["GET"]) def get_invited_user(invited_user_id): + check_suspicious_id(invited_user_id) invited_user = get_invited_user_by_id(invited_user_id) return jsonify(data=InvitedUserSchema(session=db.session).dump(invited_user)), 200 diff --git a/app/template/rest.py b/app/template/rest.py index c198bc618..65faecef9 100644 --- a/app/template/rest.py +++ b/app/template/rest.py @@ -26,7 +26,7 @@ from app.template.template_schemas import ( post_create_template_schema, post_update_template_schema, ) -from app.utils import get_public_notify_type_text +from app.utils import check_suspicious_id, get_public_notify_type_text from notifications_utils import SMS_CHAR_COUNT_LIMIT from notifications_utils.template import SMSMessageTemplate @@ -61,6 +61,7 @@ def validate_parent_folder(template_json): @template_blueprint.route("", methods=["POST"]) def create_template(service_id): + check_suspicious_id(service_id) fetched_service = dao_fetch_service_by_id(service_id=service_id) # permissions needs to be placed here otherwise marshmallow will interfere with versioning permissions = [p.permission for p in fetched_service.permissions] @@ -96,6 +97,7 @@ def create_template(service_id): @template_blueprint.route("/", methods=["POST"]) def update_template(service_id, template_id): + check_suspicious_id(service_id, template_id) fetched_template = dao_get_template_by_id_and_service_id( template_id=template_id, service_id=service_id ) @@ -146,6 +148,7 @@ def update_template(service_id, template_id): @template_blueprint.route("", methods=["GET"]) def get_all_templates_for_service(service_id): + check_suspicious_id(service_id) templates = dao_get_all_templates_for_service(service_id=service_id) if str(request.args.get("detailed", True)) == "True": data = template_schema.dump(templates, many=True) @@ -156,6 +159,7 @@ def get_all_templates_for_service(service_id): @template_blueprint.route("/", methods=["GET"]) def get_template_by_id_and_service_id(service_id, template_id): + check_suspicious_id(service_id, template_id) fetched_template = dao_get_template_by_id_and_service_id( template_id=template_id, service_id=service_id ) @@ -165,6 +169,7 @@ def get_template_by_id_and_service_id(service_id, template_id): @template_blueprint.route("//preview", methods=["GET"]) def preview_template_by_id_and_service_id(service_id, template_id): + check_suspicious_id(service_id, template_id) fetched_template = dao_get_template_by_id_and_service_id( template_id=template_id, service_id=service_id ) @@ -193,6 +198,7 @@ def preview_template_by_id_and_service_id(service_id, template_id): @template_blueprint.route("//version/") def get_template_version(service_id, template_id, version): + check_suspicious_id(service_id, template_id) data = template_history_schema.dump( dao_get_template_by_id_and_service_id( template_id=template_id, service_id=service_id, version=version @@ -203,6 +209,7 @@ def get_template_version(service_id, template_id, version): @template_blueprint.route("//versions") def get_template_versions(service_id, template_id): + check_suspicious_id(service_id, template_id) data = template_history_schema.dump( dao_get_template_versions(service_id=service_id, template_id=template_id), many=True, diff --git a/app/template_folder/rest.py b/app/template_folder/rest.py index 4f2073712..846cb81ab 100644 --- a/app/template_folder/rest.py +++ b/app/template_folder/rest.py @@ -20,6 +20,7 @@ from app.template_folder.template_folder_schema import ( post_move_template_folder_schema, post_update_template_folder_schema, ) +from app.utils import check_suspicious_id template_folder_blueprint = Blueprint( "template_folder", __name__, url_prefix="/service//template-folder" @@ -37,6 +38,7 @@ def handle_integrity_error(exc): @template_folder_blueprint.route("", methods=["GET"]) def get_template_folders_for_service(service_id): + check_suspicious_id(service_id) service = dao_fetch_service_by_id(service_id) template_folders = [o.serialize() for o in service.all_template_folders] @@ -45,6 +47,7 @@ def get_template_folders_for_service(service_id): @template_folder_blueprint.route("", methods=["POST"]) def create_template_folder(service_id): + check_suspicious_id(service_id) data = request.get_json() validate(data, post_create_template_folder_schema) @@ -72,6 +75,7 @@ def create_template_folder(service_id): @template_folder_blueprint.route("/", methods=["POST"]) def update_template_folder(service_id, template_folder_id): + check_suspicious_id(service_id, template_folder_id) data = request.get_json() validate(data, post_update_template_folder_schema) @@ -93,6 +97,7 @@ def update_template_folder(service_id, template_folder_id): @template_folder_blueprint.route("/", methods=["DELETE"]) def delete_template_folder(service_id, template_folder_id): + check_suspicious_id(service_id, template_folder_id) template_folder = dao_get_template_folder_by_id_and_service_id( template_folder_id, service_id ) @@ -112,6 +117,8 @@ def delete_template_folder(service_id, template_folder_id): ) @autocommit def move_to_template_folder(service_id, target_template_folder_id=None): + check_suspicious_id(service_id, target_template_folder_id) + data = request.get_json() validate(data, post_move_template_folder_schema) diff --git a/app/template_statistics/rest.py b/app/template_statistics/rest.py index cf4482caa..314373e73 100644 --- a/app/template_statistics/rest.py +++ b/app/template_statistics/rest.py @@ -6,7 +6,7 @@ from app.dao.fact_notification_status_dao import ( from app.dao.notifications_dao import dao_get_last_date_template_was_used from app.dao.templates_dao import dao_get_template_by_id_and_service_id from app.errors import InvalidRequest, register_errors -from app.utils import DATETIME_FORMAT +from app.utils import DATETIME_FORMAT, check_suspicious_id template_statistics = Blueprint( "template_statistics", @@ -19,6 +19,7 @@ register_errors(template_statistics) @template_statistics.route("") def get_template_statistics_for_service_by_day(service_id): + check_suspicious_id(service_id) whole_days = request.args.get("whole_days", request.args.get("limit_days", "")) try: whole_days = int(whole_days) @@ -56,6 +57,7 @@ def get_template_statistics_for_service_by_day(service_id): @template_statistics.route("/last-used/") def get_last_used_datetime_for_template(service_id, template_id): + check_suspicious_id(service_id, template_id) # Check the template and service exist dao_get_template_by_id_and_service_id(template_id, service_id) diff --git a/app/upload/rest.py b/app/upload/rest.py index 7ac3f07ac..a11b0bf0a 100644 --- a/app/upload/rest.py +++ b/app/upload/rest.py @@ -4,7 +4,7 @@ from app.dao.fact_notification_status_dao import fetch_notification_statuses_for from app.dao.jobs_dao import dao_get_notification_outcomes_for_job from app.dao.uploads_dao import dao_get_uploads_by_service_id from app.errors import register_errors -from app.utils import midnight_n_days_ago, pagination_links +from app.utils import check_suspicious_id, midnight_n_days_ago, pagination_links upload_blueprint = Blueprint( "upload", __name__, url_prefix="/service//upload" @@ -15,6 +15,7 @@ register_errors(upload_blueprint) @upload_blueprint.route("", methods=["GET"]) def get_uploads_by_service(service_id): + check_suspicious_id(service_id) return jsonify( **get_paginated_uploads( service_id, diff --git a/app/user/rest.py b/app/user/rest.py index da86521ff..a1eb93eaa 100644 --- a/app/user/rest.py +++ b/app/user/rest.py @@ -53,7 +53,13 @@ from app.user.users_schema import ( post_verify_code_schema, post_verify_webauthn_schema, ) -from app.utils import debug_not_production, hilite, url_with_token, utc_now +from app.utils import ( + check_suspicious_id, + debug_not_production, + hilite, + url_with_token, + utc_now, +) from notifications_utils.recipients import is_us_phone_number, use_numeric_sender user_blueprint = Blueprint("user", __name__) @@ -95,6 +101,7 @@ def create_user(): @user_blueprint.route("/", methods=["POST"]) def update_user_attribute(user_id): + check_suspicious_id(user_id) user_to_update = get_user_by_id(user_id=user_id) req_json = request.get_json() if "updated_by" in req_json: @@ -159,6 +166,7 @@ def get_sms_reply_to_for_notify_service(recipient, template): @user_blueprint.route("//archive", methods=["POST"]) def archive_user(user_id): + check_suspicious_id(user_id) user = get_user_by_id(user_id) dao_archive_user(user) @@ -167,6 +175,7 @@ def archive_user(user_id): @user_blueprint.route("//activate", methods=["POST"]) def activate_user(user_id): + check_suspicious_id(user_id) user = get_user_by_id(user_id=user_id) if user.state == "active": raise InvalidRequest("User already active", status_code=400) @@ -178,6 +187,7 @@ def activate_user(user_id): @user_blueprint.route("//deactivate", methods=["POST"]) def deactivate_user(user_id): + check_suspicious_id(user_id) user = get_user_by_id(user_id=user_id) if user.state == "pending": raise InvalidRequest("User already inactive", status_code=400) @@ -189,6 +199,7 @@ def deactivate_user(user_id): @user_blueprint.route("//reset-failed-login-count", methods=["POST"]) def user_reset_failed_login_count(user_id): + check_suspicious_id(user_id) user_to_update = get_user_by_id(user_id=user_id) reset_failed_login_count(user_to_update) return jsonify(data=user_to_update.serialize()), 200 @@ -196,6 +207,7 @@ def user_reset_failed_login_count(user_id): @user_blueprint.route("//verify/password", methods=["POST"]) def verify_user_password(user_id): + check_suspicious_id(user_id) user_to_verify = get_user_by_id(user_id=user_id) try: @@ -217,6 +229,7 @@ def verify_user_password(user_id): @user_blueprint.route("//verify/code", methods=["POST"]) def verify_user_code(user_id): + check_suspicious_id(user_id) data = request.get_json() validate(data, post_verify_code_schema) @@ -251,6 +264,7 @@ def verify_user_code(user_id): @user_blueprint.route("//complete/webauthn-login", methods=["POST"]) @user_blueprint.route("//verify/webauthn-login", methods=["POST"]) def complete_login_after_webauthn_authentication_attempt(user_id): + check_suspicious_id(user_id) """ complete login after a webauthn authentication. There's nothing webauthn specific in this code but the sms/email flows do this as part of `verify_user_code` above and this is the equivalent spot in the @@ -283,6 +297,7 @@ def complete_login_after_webauthn_authentication_attempt(user_id): @user_blueprint.route("//-code", methods=["POST"]) def send_user_2fa_code(user_id, code_type): + check_suspicious_id(user_id) user_to_send_to = get_user_by_id(user_id=user_id) if count_user_verify_codes(user_to_send_to) >= current_app.config.get( @@ -386,6 +401,7 @@ def create_2fa_code( @user_blueprint.route("//change-email-verification", methods=["POST"]) def send_user_confirm_new_email(user_id): + check_suspicious_id(user_id) user_to_send_to = get_user_by_id(user_id=user_id) email = email_data_request_schema.load(request.get_json()) @@ -425,6 +441,7 @@ def send_user_confirm_new_email(user_id): @user_blueprint.route("//email-verification", methods=["POST"]) def send_new_user_email_verification(user_id): + check_suspicious_id(user_id) current_app.logger.info("Sending email verification for user {}".format(user_id)) request_json = request.get_json() @@ -479,6 +496,7 @@ def send_new_user_email_verification(user_id): @user_blueprint.route("//email-already-registered", methods=["POST"]) def send_already_registered_email(user_id): + check_suspicious_id(user_id) current_app.logger.info("Email already registered for user {}".format(user_id)) to = email_data_request_schema.load(request.get_json()) @@ -528,6 +546,7 @@ def send_already_registered_email(user_id): @user_blueprint.route("/", methods=["GET"]) @user_blueprint.route("", methods=["GET"]) def get_user(user_id=None): + check_suspicious_id(user_id) users = get_user_by_id(user_id=user_id) result = ( [x.serialize() for x in users] if isinstance(users, list) else users.serialize() @@ -539,6 +558,7 @@ def get_user(user_id=None): "//service//permission", methods=["POST"] ) def set_permissions(user_id, service_id): + check_suspicious_id(user_id, service_id) # TODO fix security hole, how do we verify that the user # who is making this request has permission to make the request. service_user = dao_get_service_user(user_id, service_id) @@ -651,6 +671,7 @@ def report_all_users(): @user_blueprint.route("//organizations-and-services", methods=["GET"]) def get_organizations_and_services_for_user(user_id): + check_suspicious_id(user_id) user = get_user_and_accounts(user_id) data = get_orgs_and_services(user) return jsonify(data) diff --git a/app/utils.py b/app/utils.py index 07c2571b1..479ef6503 100644 --- a/app/utils.py +++ b/app/utils.py @@ -1,7 +1,8 @@ import os +import re from datetime import datetime, timedelta, timezone -from flask import current_app, url_for +from flask import abort, current_app, url_for from sqlalchemy import func from notifications_utils.template import HTMLEmailTemplate, SMSMessageTemplate @@ -158,3 +159,52 @@ def emit_job_update_summary(job): }, room=f"job-{job.id}", ) + + +def is_suspicious_input(input_str): + if not isinstance(input_str, str): + return False + + pattern = re.compile( + r""" + (?i) # case insensite + \b # word boundary + ( # start of group for SQL keywords + OR # match SQL keyword OR + |AND + |UNION + |SELECT + |DROP + |INSERT + |UPDATE + |DELETE + |EXEC + |TRUNCATE + |CREATE + |ALTER + |-- # match SQL single-line comment + |/\* # match SQL multi-line comment + |\bpg_sleep\b # Match PostgreSQL 'pg_sleep' function + + |\bsleep\b # Match SQL Server 'sleep' function + ) # End SQL keywords and function group + | # OR operator to include an alternate pattern + [';]{2,} # Match two or more consecutive single quotes or semi-colons + """, + re.VERBOSE, + ) + return bool(re.search(pattern, input_str)) + + +def is_valid_id(id): + if not isinstance(id, str): + return True + return bool(re.match(r"^[a-zA-Z0-9_-]{1,50}$", id)) + + +def check_suspicious_id(*args): + for id in args: + if not is_valid_id(id): + abort(403) + if is_suspicious_input(id): + abort(403) diff --git a/app/webauthn/rest.py b/app/webauthn/rest.py index 1dba333e7..e5827b65a 100644 --- a/app/webauthn/rest.py +++ b/app/webauthn/rest.py @@ -9,6 +9,7 @@ from app.dao.webauthn_credential_dao import ( ) from app.errors import InvalidRequest, register_errors from app.schema_validation import validate +from app.utils import check_suspicious_id from app.webauthn.webauthn_schema import ( post_create_webauthn_credential_schema, post_update_webauthn_credential_schema, @@ -28,6 +29,7 @@ def get_webauthn_credentials(user_id): @webauthn_blueprint.route("", methods=["POST"]) def create_webauthn_credential(user_id): + check_suspicious_id(user_id) data = request.get_json() validate(data, post_create_webauthn_credential_schema) webauthn_credential = dao_create_webauthn_credential( @@ -42,6 +44,7 @@ def create_webauthn_credential(user_id): @webauthn_blueprint.route("/", methods=["POST"]) def update_webauthn_credential(user_id, webauthn_credential_id): + check_suspicious_id(user_id, webauthn_credential_id) data = request.get_json() validate(data, post_update_webauthn_credential_schema) @@ -56,6 +59,7 @@ def update_webauthn_credential(user_id, webauthn_credential_id): @webauthn_blueprint.route("/", methods=["DELETE"]) def delete_webauthn_credential(user_id, webauthn_credential_id): + check_suspicious_id(user_id, webauthn_credential_id) webauthn_credential = dao_get_webauthn_credential_by_user_and_id( user_id, webauthn_credential_id ) diff --git a/tests/app/job/test_rest.py b/tests/app/job/test_rest.py index dbba6d729..f65ce2458 100644 --- a/tests/app/job/test_rest.py +++ b/tests/app/job/test_rest.py @@ -5,7 +5,6 @@ from unittest.mock import ANY from zoneinfo import ZoneInfo import pytest -import werkzeug from freezegun import freeze_time import app.celery.tasks @@ -17,7 +16,6 @@ from app.enums import ( NotificationType, TemplateType, ) -from app.job.rest import check_suspicious_id, is_suspicious_input, is_valid_id from app.utils import utc_now from tests import create_admin_authorization_header from tests.app.db import ( @@ -588,31 +586,6 @@ def test_get_all_notifications_for_job_returns_correct_format( assert resp["notifications"][0]["status"] == sample_notification_with_job.status -def test_is_valid_id(sample_job): - returnVal = is_valid_id(sample_job.service_id) - assert returnVal is True - - returnVal = is_valid_id("abc pgsleep(1)") - assert returnVal is False - - -def test_check_suspicious_id(sample_job): - # This should be good - check_suspicious_id(sample_job.id, sample_job.service_id) - - # This should be bad - with pytest.raises(werkzeug.exceptions.Forbidden): - check_suspicious_id(sample_job.id, "what is this???") - - -def test_is_suspicious_input(sample_job): - returnVal = is_suspicious_input(sample_job.id) - assert returnVal is False - - returnVal = is_suspicious_input("1 OR pg_sleep(1)") - assert returnVal is True - - def test_get_notification_count_for_job_id(admin_request, mocker, sample_job): mock_dao = mocker.patch( "app.job.rest.dao_get_notification_count_for_job_id", return_value=3 diff --git a/tests/app/test_utils.py b/tests/app/test_utils.py index 024c024a8..9969f1b56 100644 --- a/tests/app/test_utils.py +++ b/tests/app/test_utils.py @@ -1,13 +1,17 @@ from datetime import date, datetime import pytest +import werkzeug from freezegun import freeze_time from app.enums import ServicePermissionType, TemplateType from app.utils import ( + check_suspicious_id, get_midnight_in_utc, get_public_notify_type_text, get_template_instance, + is_suspicious_input, + is_valid_id, midnight_n_days_ago, ) from notifications_utils.template import HTMLEmailTemplate, SMSMessageTemplate @@ -141,3 +145,31 @@ def test_get_template_instance_comprehensive(template_type, values): assert isinstance(result, SMSMessageTemplate) else: assert isinstance(result, HTMLEmailTemplate) + + +def test_is_valid_id(sample_job): + returnVal = is_valid_id(sample_job.service_id) + assert returnVal is True + + returnVal = is_valid_id("abc pgsleep(1)") + assert returnVal is False + + +def test_check_suspicious_id(sample_job): + # This should be good + check_suspicious_id(sample_job.id, sample_job.service_id) + + # This should be bad + with pytest.raises(werkzeug.exceptions.Forbidden): + check_suspicious_id(sample_job.id, "what is this???") + + # This should be good + check_suspicious_id(sample_job.id, None) + + +def test_is_suspicious_input(sample_job): + returnVal = is_suspicious_input(sample_job.id) + assert returnVal is False + + returnVal = is_suspicious_input("1 OR pg_sleep(1)") + assert returnVal is True