diff --git a/tests/app/job/test_rest.py b/tests/app/job/test_rest.py index 5daa58595..f4dc01392 100644 --- a/tests/app/job/test_rest.py +++ b/tests/app/job/test_rest.py @@ -8,6 +8,7 @@ import pytz import app.celery.tasks from tests import create_authorization_header +from tests.conftest import set_config from tests.app.conftest import ( sample_job as create_job, sample_notification as create_notification @@ -612,11 +613,11 @@ def test_get_jobs_should_paginate( ): create_10_jobs(notify_db, notify_db_session, sample_template.service, sample_template) - client.application.config['PAGE_SIZE'] = 2 path = '/service/{}/job'.format(sample_template.service_id) auth_header = create_authorization_header(service_id=str(sample_template.service_id)) - response = client.get(path, headers=[auth_header]) + with set_config(client.application, 'PAGE_SIZE', 2): + response = client.get(path, headers=[auth_header]) assert response.status_code == 200 resp_json = json.loads(response.get_data(as_text=True)) @@ -637,11 +638,11 @@ def test_get_jobs_accepts_page_parameter( ): create_10_jobs(notify_db, notify_db_session, sample_template.service, sample_template) - client.application.config['PAGE_SIZE'] = 2 path = '/service/{}/job'.format(sample_template.service_id) auth_header = create_authorization_header(service_id=str(sample_template.service_id)) - response = client.get(path, headers=[auth_header], query_string={'page': 2}) + with set_config(client.application, 'PAGE_SIZE', 2): + response = client.get(path, headers=[auth_header], query_string={'page': 2}) assert response.status_code == 200 resp_json = json.loads(response.get_data(as_text=True)) diff --git a/tests/conftest.py b/tests/conftest.py index 3da2dead3..11f570de5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager import os import boto3 @@ -83,3 +84,11 @@ def pytest_generate_tests(metafunc): argnames, testdata = idparametrize.args ids, argvalues = zip(*sorted(testdata.items())) metafunc.parametrize(argnames, argvalues, ids=ids) + + +@contextmanager +def set_config(app, name, value): + old_val = app.config.get(name) + app.config[name] = value + yield + app.config[name] = old_val