diff --git a/app/job/rest.py b/app/job/rest.py index 3e06aeb53..5189c6917 100644 --- a/app/job/rest.py +++ b/app/job/rest.py @@ -68,8 +68,7 @@ def check_suspicious_id(*args): @job_blueprint.route("/", methods=["GET"]) def get_job_by_service_and_job_id(service_id, job_id): - check_suspicious_id(service_id) - check_suspicious_id(job_id) + check_suspicious_id(service_id, job_id) job = dao_get_job_by_service_id_and_job_id(service_id, job_id) statistics = dao_get_notification_outcomes_for_job(service_id, job_id) data = JobSchema(session=db.session).dump(job) @@ -83,8 +82,7 @@ def get_job_by_service_and_job_id(service_id, job_id): @job_blueprint.route("//cancel", methods=["POST"]) def cancel_job(service_id, job_id): - check_suspicious_id(service_id) - check_suspicious_id(job_id) + check_suspicious_id(service_id, job_id) job = dao_get_future_scheduled_job_by_id_and_service_id(job_id, service_id) job.job_status = JobStatus.CANCELLED @@ -95,8 +93,7 @@ def cancel_job(service_id, job_id): @job_blueprint.route("//notifications", methods=["GET"]) def get_all_notifications_for_service_job(service_id, job_id): - check_suspicious_id(service_id) - check_suspicious_id(job_id) + check_suspicious_id(service_id, job_id) data = notifications_filter_schema.load(request.args) page = data["page"] if "page" in data else 1 @@ -159,8 +156,7 @@ def get_all_notifications_for_service_job(service_id, job_id): @job_blueprint.route("//recent_notifications", methods=["GET"]) def get_recent_notifications_for_service_job(service_id, job_id): - check_suspicious_id(service_id) - check_suspicious_id(job_id) + check_suspicious_id(service_id, job_id) data = notifications_filter_schema.load(request.args) page = data["page"] if "page" in data else 1 @@ -227,8 +223,7 @@ def get_recent_notifications_for_service_job(service_id, job_id): @job_blueprint.route("//notification_count", methods=["GET"]) def get_notification_count_for_job_id(service_id, job_id): - check_suspicious_id(service_id) - check_suspicious_id(job_id) + check_suspicious_id(service_id, job_id) dao_get_job_by_service_id_and_job_id(service_id, job_id) count = dao_get_notification_count_for_job_id(job_id=job_id) diff --git a/tests/app/job/test_rest.py b/tests/app/job/test_rest.py index e39b24b54..dbba6d729 100644 --- a/tests/app/job/test_rest.py +++ b/tests/app/job/test_rest.py @@ -5,6 +5,7 @@ from unittest.mock import ANY from zoneinfo import ZoneInfo import pytest +import werkzeug from freezegun import freeze_time import app.celery.tasks @@ -16,7 +17,7 @@ from app.enums import ( NotificationType, TemplateType, ) -from app.job.rest import is_suspicious_input, is_valid_id +from app.job.rest import check_suspicious_id, is_suspicious_input, is_valid_id from app.utils import utc_now from tests import create_admin_authorization_header from tests.app.db import ( @@ -595,6 +596,15 @@ def test_is_valid_id(sample_job): assert returnVal is False +def test_check_suspicious_id(sample_job): + # This should be good + check_suspicious_id(sample_job.id, sample_job.service_id) + + # This should be bad + with pytest.raises(werkzeug.exceptions.Forbidden): + check_suspicious_id(sample_job.id, "what is this???") + + def test_is_suspicious_input(sample_job): returnVal = is_suspicious_input(sample_job.id) assert returnVal is False