diff --git a/app/dao/services_dao.py b/app/dao/services_dao.py index 604f0dd25..bead23a6e 100644 --- a/app/dao/services_dao.py +++ b/app/dao/services_dao.py @@ -75,6 +75,7 @@ def dao_fetch_all_services_by_user(user_id, only_active=False): @version_class(ApiKey) def dao_deactive_service(service_id): # have to eager load templates and api keys so that we don't flush when we loop through them + # to ensure that db.session still contains the models when it comes to creating history objects service = Service.query.options( joinedload('templates'), joinedload('api_keys'), @@ -84,17 +85,12 @@ def dao_deactive_service(service_id): service.name = '_archived_' + service.name service.email_from = '_archived_' + service.email_from - for template in service.templates: template.archived = True for api_key in service.api_keys: api_key.expiry_date = datetime.utcnow() - db.session.add(service) - db.session.add_all(service.templates) - db.session.add_all(service.api_keys) - def dao_fetch_service_by_id_and_user(service_id, user_id): return Service.query.filter( diff --git a/app/models.py b/app/models.py index 2f2c6163d..699ae2275 100644 --- a/app/models.py +++ b/app/models.py @@ -5,7 +5,7 @@ from sqlalchemy.dialects.postgresql import ( UUID, JSON ) -from sqlalchemy import UniqueConstraint, and_, false, true +from sqlalchemy import UniqueConstraint, and_ from sqlalchemy.orm import foreign, remote from notifications_utils.recipients import ( validate_email_address, diff --git a/tests/__init__.py b/tests/__init__.py index 526e0efc1..a40d209cf 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -30,3 +30,12 @@ def create_authorization_header(service_id=None, key_type=KEY_TYPE_NORMAL): token = create_jwt_token(secret=secret, client_id=client_id) return 'Authorization', 'Bearer {}'.format(token) + + +def unwrap_function(fn): + """ + Given a function, returns its undecorated original. + """ + while hasattr(fn, '__wrapped__'): + fn = fn.__wrapped__ + return fn diff --git a/tests/app/service/test_deactivate.py b/tests/app/service/test_deactivate.py index df39fecac..0f9a3dcce 100644 --- a/tests/app/service/test_deactivate.py +++ b/tests/app/service/test_deactivate.py @@ -1,22 +1,26 @@ import uuid +from unittest import mock import pytest -from app.models import Service -from tests import create_authorization_header +from app import db +from app.models import Service, TemplateHistory, ApiKey +from app.dao.services_dao import dao_deactive_service + +from tests import create_authorization_header, unwrap_function from tests.app.conftest import ( sample_template as create_template, sample_api_key as create_api_key ) -def test_deactivate_only_allows_post(client, sample_service): +def test_deactivate_only_allows_post(client): auth_header = create_authorization_header() response = client.get('/service/{}/deactivate'.format(uuid.uuid4()), headers=[auth_header]) assert response.status_code == 405 -def test_deactivate_service_errors_with_bad_service_id(client, sample_service): +def test_deactivate_service_errors_with_bad_service_id(client): auth_header = create_authorization_header() response = client.post('/service/{}/deactivate'.format(uuid.uuid4()), headers=[auth_header]) assert response.status_code == 404 @@ -73,3 +77,13 @@ def test_deactivating_service_creates_history(deactivated_service): assert history.version == 2 assert history.active is False + + +def test_deactivating_service_rolls_back_everything_on_error(sample_service, sample_api_key, sample_template): + unwrapped_deactive_service = unwrap_function(dao_deactive_service) + + unwrapped_deactive_service(sample_service.id) + + assert sample_service in db.session.dirty + assert sample_api_key in db.session.dirty + assert sample_template in db.session.dirty