diff --git a/app/job/rest.py b/app/job/rest.py index 0f80fab2d..831c3f2f4 100644 --- a/app/job/rest.py +++ b/app/job/rest.py @@ -1,7 +1,8 @@ +import re from zoneinfo import ZoneInfo import dateutil -from flask import Blueprint, current_app, jsonify, request +from flask import Blueprint, abort, current_app, jsonify, request from app import db from app.aws.s3 import ( @@ -44,8 +45,58 @@ job_blueprint = Blueprint("job", __name__, url_prefix="/service/", methods=["GET"]) def get_job_by_service_and_job_id(service_id, job_id): + check_suspicious_id(service_id, job_id) job = dao_get_job_by_service_id_and_job_id(service_id, job_id) statistics = dao_get_notification_outcomes_for_job(service_id, job_id) data = JobSchema(session=db.session).dump(job) @@ -59,6 +110,8 @@ def get_job_by_service_and_job_id(service_id, job_id): @job_blueprint.route("//cancel", methods=["POST"]) def cancel_job(service_id, job_id): + check_suspicious_id(service_id, job_id) + job = dao_get_future_scheduled_job_by_id_and_service_id(job_id, service_id) job.job_status = JobStatus.CANCELLED dao_update_job(job) @@ -68,6 +121,8 @@ def cancel_job(service_id, job_id): @job_blueprint.route("//notifications", methods=["GET"]) def get_all_notifications_for_service_job(service_id, job_id): + check_suspicious_id(service_id, job_id) + data = notifications_filter_schema.load(request.args) page = data["page"] if "page" in data else 1 page_size = ( @@ -129,6 +184,8 @@ def get_all_notifications_for_service_job(service_id, job_id): @job_blueprint.route("//recent_notifications", methods=["GET"]) def get_recent_notifications_for_service_job(service_id, job_id): + check_suspicious_id(service_id, job_id) + data = notifications_filter_schema.load(request.args) page = data["page"] if "page" in data else 1 page_size = ( @@ -194,6 +251,8 @@ def get_recent_notifications_for_service_job(service_id, job_id): @job_blueprint.route("//notification_count", methods=["GET"]) def get_notification_count_for_job_id(service_id, job_id): + check_suspicious_id(service_id, job_id) + dao_get_job_by_service_id_and_job_id(service_id, job_id) count = dao_get_notification_count_for_job_id(job_id=job_id) return jsonify(count=count), 200 @@ -201,6 +260,7 @@ def get_notification_count_for_job_id(service_id, job_id): @job_blueprint.route("", methods=["GET"]) def get_jobs_by_service(service_id): + check_suspicious_id(service_id) if request.args.get("limit_days"): try: limit_days = int(request.args["limit_days"]) @@ -234,6 +294,8 @@ def get_jobs_by_service(service_id): @job_blueprint.route("", methods=["POST"]) def create_job(service_id): + check_suspicious_id(service_id) + """Entry point from UI for one-off messages as well as CSV uploads.""" service = dao_fetch_service_by_id(service_id) if not service.active: @@ -253,6 +315,7 @@ def create_job(service_id): ) data["template"] = data.pop("template_id") + check_suspicious_id(data["template"]) template = dao_get_template_by_id(data["template"]) if data.get("valid") != "True": @@ -277,6 +340,7 @@ def create_job(service_id): dao_create_job(job) sender_id = data.get("sender_id") + check_suspicious_id(sender_id) # Kick off job in tasks.py if job.job_status == JobStatus.PENDING: process_job.apply_async( @@ -291,6 +355,8 @@ def create_job(service_id): @job_blueprint.route("/scheduled-job-stats", methods=["GET"]) def get_scheduled_job_stats(service_id): + check_suspicious_id(service_id) + count, soonest_scheduled_for = dao_get_scheduled_job_stats(service_id) return ( jsonify( diff --git a/tests/app/job/test_rest.py b/tests/app/job/test_rest.py index f65ce2458..dbba6d729 100644 --- a/tests/app/job/test_rest.py +++ b/tests/app/job/test_rest.py @@ -5,6 +5,7 @@ from unittest.mock import ANY from zoneinfo import ZoneInfo import pytest +import werkzeug from freezegun import freeze_time import app.celery.tasks @@ -16,6 +17,7 @@ from app.enums import ( NotificationType, TemplateType, ) +from app.job.rest import check_suspicious_id, is_suspicious_input, is_valid_id from app.utils import utc_now from tests import create_admin_authorization_header from tests.app.db import ( @@ -586,6 +588,31 @@ def test_get_all_notifications_for_job_returns_correct_format( assert resp["notifications"][0]["status"] == sample_notification_with_job.status +def test_is_valid_id(sample_job): + returnVal = is_valid_id(sample_job.service_id) + assert returnVal is True + + returnVal = is_valid_id("abc pgsleep(1)") + assert returnVal is False + + +def test_check_suspicious_id(sample_job): + # This should be good + check_suspicious_id(sample_job.id, sample_job.service_id) + + # This should be bad + with pytest.raises(werkzeug.exceptions.Forbidden): + check_suspicious_id(sample_job.id, "what is this???") + + +def test_is_suspicious_input(sample_job): + returnVal = is_suspicious_input(sample_job.id) + assert returnVal is False + + returnVal = is_suspicious_input("1 OR pg_sleep(1)") + assert returnVal is True + + def test_get_notification_count_for_job_id(admin_request, mocker, sample_job): mock_dao = mocker.patch( "app.job.rest.dao_get_notification_count_for_job_id", return_value=3