diff --git a/app/provider_details/rest.py b/app/provider_details/rest.py index b075653f5..4b7ae2a7a 100644 --- a/app/provider_details/rest.py +++ b/app/provider_details/rest.py @@ -37,9 +37,10 @@ def update_provider_details(provider_details_id): current_data.update(request.get_json()) update_dict = provider_details_schema.load(current_data).data - if "identifier" in request.get_json().keys(): + invalid_keys = {'identifier', 'version'} & set(key for key in request.get_json().keys()) + if invalid_keys: message = "Not permitted to be updated" - errors = {'identifier': [message]} + errors = {key: [message] for key in invalid_keys} raise InvalidRequest(errors, status_code=400) dao_update_provider_details(update_dict) diff --git a/tests/app/conftest.py b/tests/app/conftest.py index 674b99e31..8e570228b 100644 --- a/tests/app/conftest.py +++ b/tests/app/conftest.py @@ -1,6 +1,7 @@ import uuid from datetime import (datetime, date, timedelta) +from sqlalchemy.orm.session import make_transient import requests_mock import pytest from flask import current_app @@ -18,6 +19,8 @@ from app.models import ( Permission, ProviderStatistics, ProviderDetails, + ProviderDetailsHistory, + ProviderRates, NotificationStatistics, ServiceWhitelist, KEY_TYPE_NORMAL, KEY_TYPE_TEST, KEY_TYPE_TEAM, @@ -857,3 +860,33 @@ def sample_provider_rate(notify_db, notify_db_session, valid_from=None, rate=Non valid_from=valid_from if valid_from is not None else datetime.utcnow(), rate=rate if rate is not None else 1, ) + + +@pytest.fixture +def restore_provider_details(notify_db, notify_db_session): + """ + We view ProviderDetails as a static in notify_db_session, since we don't modify it... except we do, we updated + priority. This fixture is designed to be used in tests that will knowingly touch provider details, to restore them + to previous state. + + Note: This doesn't technically require notify_db_session (only notify_db), but kept as a requirement to encourage + good usage - if you're modifying ProviderDetails' state then it's good to clear down the rest of the DB too + """ + existing_provider_details = ProviderDetails.query.all() + existing_provider_details_history = ProviderDetailsHistory.query.all() + # make transient removes the objects from the session - since we'll want to delete them later + for epd in existing_provider_details: + make_transient(epd) + for epdh in existing_provider_details_history: + make_transient(epdh) + + yield + + # also delete these as they depend on provider_details + ProviderRates.query.delete() + ProviderDetails.query.delete() + ProviderDetailsHistory.query.delete() + notify_db.session.commit() + notify_db.session.add_all(existing_provider_details) + notify_db.session.add_all(existing_provider_details_history) + notify_db.session.commit() diff --git a/tests/app/provider_details/test_rest.py b/tests/app/provider_details/test_rest.py index b6a26a3e6..dc224e9b1 100644 --- a/tests/app/provider_details/test_rest.py +++ b/tests/app/provider_details/test_rest.py @@ -1,150 +1,96 @@ +import pytest from flask import json + +from app.models import ProviderDetails + from tests import create_authorization_header +def test_get_provider_details_in_type_and_identifier_order(client, notify_db): + response = client.get( + '/provider-details', + headers=[create_authorization_header()] + ) + assert response.status_code == 200 + json_resp = json.loads(response.get_data(as_text=True))['provider_details'] + assert len(json_resp) == 4 -def test_get_provider_details_in_type_and_identifier_order(notify_db, notify_db_session, notify_api): - with notify_api.test_request_context(): - with notify_api.test_client() as client: - auth_header = create_authorization_header() - response = client.get( - '/provider-details', - headers=[auth_header] - ) - assert response.status_code == 200 - json_resp = json.loads(response.get_data(as_text=True))['provider_details'] - assert len(json_resp) == 4 - - assert json_resp[0]['identifier'] == 'ses' - assert json_resp[1]['identifier'] == 'mmg' - assert json_resp[2]['identifier'] == 'firetext' - assert json_resp[3]['identifier'] == 'loadtesting' + assert json_resp[0]['identifier'] == 'ses' + assert json_resp[1]['identifier'] == 'mmg' + assert json_resp[2]['identifier'] == 'firetext' + assert json_resp[3]['identifier'] == 'loadtesting' -def test_get_provider_details_by_id(notify_db, notify_db_session, notify_api): - with notify_api.test_request_context(): - with notify_api.test_client() as client: - auth_header = create_authorization_header() - response = client.get( - '/provider-details', - headers=[auth_header] - ) - json_resp = json.loads(response.get_data(as_text=True))['provider_details'] +def test_get_provider_details_by_id(client, notify_db): + response = client.get( + '/provider-details', + headers=[create_authorization_header()] + ) + json_resp = json.loads(response.get_data(as_text=True))['provider_details'] - provider_resp = client.get( - '/provider-details/{}'.format(json_resp[0]['id']), - headers=[auth_header] - ) + provider_resp = client.get( + '/provider-details/{}'.format(json_resp[0]['id']), + headers=[create_authorization_header()] + ) - provider = json.loads(provider_resp.get_data(as_text=True))['provider_details'] - assert provider['identifier'] == json_resp[0]['identifier'] + provider = json.loads(provider_resp.get_data(as_text=True))['provider_details'] + assert provider['identifier'] == json_resp[0]['identifier'] -def test_get_provider_details_contains_correct_fields(notify_db, notify_db_session, notify_api): - with notify_api.test_request_context(): - with notify_api.test_client() as client: - auth_header = create_authorization_header() - response = client.get( - '/provider-details', - headers=[auth_header] - ) - json_resp = json.loads(response.get_data(as_text=True))['provider_details'] - allowed_keys = {"id", "display_name", "identifier", "priority", 'notification_type', "active"} - assert \ - allowed_keys == \ - set(json_resp[0].keys()) +def test_get_provider_details_contains_correct_fields(client, notify_db): + response = client.get( + '/provider-details', + headers=[create_authorization_header()] + ) + json_resp = json.loads(response.get_data(as_text=True))['provider_details'] + allowed_keys = {"id", "display_name", "identifier", "priority", 'notification_type', "active", "version"} + assert allowed_keys == set(json_resp[0].keys()) -def test_should_be_able_to_update_priority(notify_db, notify_db_session, notify_api): - with notify_api.test_request_context(): - with notify_api.test_client() as client: - auth_header = create_authorization_header() - response = client.get( - '/provider-details', - headers=[auth_header] - ) - fetch_resp = json.loads(response.get_data(as_text=True))['provider_details'] +def test_should_be_able_to_update_priority(client, restore_provider_details): + provider = ProviderDetails.query.first() - provider_id = fetch_resp[2]['id'] - - update_resp = client.post( - '/provider-details/{}'.format(provider_id), - headers=[('Content-Type', 'application/json'), auth_header], - data=json.dumps({ - 'priority': 5 - }) - ) - assert update_resp.status_code == 200 - update_json = json.loads(update_resp.get_data(as_text=True))['provider_details'] - assert update_json['identifier'] == 'firetext' - assert update_json['priority'] == 5 - - update_resp = client.post( - '/provider-details/{}'.format(provider_id), - headers=[('Content-Type', 'application/json'), auth_header], - data=json.dumps({ - 'priority': 20 - }) - ) - assert update_resp.status_code == 200 + update_resp = client.post( + '/provider-details/{}'.format(provider.id), + headers=[('Content-Type', 'application/json'), create_authorization_header()], + data=json.dumps({ + 'priority': 5 + }) + ) + assert update_resp.status_code == 200 + update_json = json.loads(update_resp.get_data(as_text=True))['provider_details'] + assert update_json['identifier'] == provider.identifier + assert update_json['priority'] == 5 + assert provider.priority == 5 -def test_should_be_able_to_update_status(notify_db, notify_db_session, notify_api): - with notify_api.test_request_context(): - with notify_api.test_client() as client: - auth_header = create_authorization_header() - response = client.get( - '/provider-details', - headers=[auth_header] - ) - fetch_resp = json.loads(response.get_data(as_text=True))['provider_details'] +def test_should_be_able_to_update_status(client, restore_provider_details): + provider = ProviderDetails.query.first() - firetext = next(x for x in fetch_resp if x['identifier'] == 'firetext') - - update_resp_1 = client.post( - '/provider-details/{}'.format(firetext['id']), - headers=[('Content-Type', 'application/json'), auth_header], - data=json.dumps({ - 'active': False - }) - ) - assert update_resp_1.status_code == 200 - update_resp_1 = json.loads(update_resp_1.get_data(as_text=True))['provider_details'] - assert update_resp_1['identifier'] == 'firetext' - assert not update_resp_1['active'] - - update_resp_2 = client.post( - '/provider-details/{}'.format(firetext['id']), - headers=[('Content-Type', 'application/json'), auth_header], - data=json.dumps({ - 'active': True - }) - ) - assert update_resp_2.status_code == 200 - update_resp_2 = json.loads(update_resp_2.get_data(as_text=True))['provider_details'] - assert update_resp_2['identifier'] == 'firetext' - assert update_resp_2['active'] + update_resp_1 = client.post( + '/provider-details/{}'.format(provider.id), + headers=[('Content-Type', 'application/json'), create_authorization_header()], + data=json.dumps({ + 'active': False + }) + ) + assert update_resp_1.status_code == 200 + update_resp_1 = json.loads(update_resp_1.get_data(as_text=True))['provider_details'] + assert update_resp_1['identifier'] == provider.identifier + assert not update_resp_1['active'] + assert not provider.active -def test_should_not_be_able_to_update_identifier(notify_db, notify_db_session, notify_api): - with notify_api.test_request_context(): - with notify_api.test_client() as client: - auth_header = create_authorization_header() - response = client.get( - '/provider-details', - headers=[auth_header] - ) - fetch_resp = json.loads(response.get_data(as_text=True))['provider_details'] +@pytest.mark.parametrize('field,value', [('identifier', 'new'), ('version', 7)]) +def test_should_not_be_able_to_update_disallowed_fields(client, restore_provider_details, field, value): + provider = ProviderDetails.query.first() - provider_id = fetch_resp[2]['id'] - - update_resp = client.post( - '/provider-details/{}'.format(provider_id), - headers=[('Content-Type', 'application/json'), auth_header], - data=json.dumps({ - 'identifier': "new" - }) - ) - assert update_resp.status_code == 400 - update_resp = json.loads(update_resp.get_data(as_text=True)) - assert update_resp['message']['identifier'][0] == 'Not permitted to be updated' - assert update_resp['result'] == 'error' + update_resp = client.post( + '/provider-details/{}'.format(provider.id), + headers=[('Content-Type', 'application/json'), create_authorization_header()], + data=json.dumps({field: value}) + ) + assert update_resp.status_code == 400 + update_resp = json.loads(update_resp.get_data(as_text=True)) + print(update_resp) + assert update_resp['message'][field][0] == 'Not permitted to be updated' + assert update_resp['result'] == 'error' diff --git a/tests/conftest.py b/tests/conftest.py index 17beb8a01..13e705880 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -70,7 +70,7 @@ def notify_db_session(notify_db): notify_db.session.remove() for tbl in reversed(notify_db.metadata.sorted_tables): - if tbl.name not in ["provider_details", "key_types", "branding_type", "job_status"]: + if tbl.name not in ["provider_details", "key_types", "branding_type", "job_status", "provider_details_history"]: notify_db.engine.execute(tbl.delete()) notify_db.session.commit()