From cf351356052c78432faf1a53f0b53d4cae5501ae Mon Sep 17 00:00:00 2001 From: Rebecca Law Date: Mon, 12 Apr 2021 13:52:40 +0100 Subject: [PATCH] Adding @nested_transactional for transactions that require more than one db update/insert. Using a savepoint for the multiple transactions allows us to rollback if there is an error when executing the second db transaction. However, this does add a bit of complexity. Developers need to manage the db session when calling multiple nested tranactions. Unit tests have been added to test this functionality and some end to end tests have been done to make sure all transactions are rollback if there is an exception while executing the transaction. --- app/billing/rest.py | 3 ++ app/commands.py | 2 + app/dao/annual_billing_dao.py | 6 +-- app/dao/dao_utils.py | 17 ++++++++ app/dao/organisation_dao.py | 9 +++- app/dao/services_dao.py | 9 +++- app/organisation/rest.py | 15 +++---- app/service/rest.py | 18 +++----- tests/app/dao/test_organisation_dao.py | 2 +- tests/app/organisation/test_rest.py | 59 ++++++++++++++++++-------- tests/app/service/test_rest.py | 9 ++-- 11 files changed, 99 insertions(+), 50 deletions(-) diff --git a/app/billing/rest.py b/app/billing/rest.py index 3f5f6e08c..95c79798f 100644 --- a/app/billing/rest.py +++ b/app/billing/rest.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from app import db from app.billing.billing_schemas import ( create_or_update_free_sms_fragment_limit_schema, serialize_ft_billing_remove_emails, @@ -83,6 +84,7 @@ def get_free_sms_fragment_limit(service_id): annual_billing = dao_create_or_update_annual_billing_for_year(service_id, annual_billing.free_sms_fragment_limit, financial_year_start) + db.session.commit() return jsonify(annual_billing.serialize_free_sms_items()), 200 @@ -118,3 +120,4 @@ def update_free_sms_fragment_limit_data(service_id, free_sms_fragment_limit, fin free_sms_fragment_limit, financial_year_start ) + db.session.commit() diff --git a/app/commands.py b/app/commands.py index c98ce9bf3..12801b4a8 100644 --- a/app/commands.py +++ b/app/commands.py @@ -880,6 +880,7 @@ def populate_annual_billing_with_the_previous_years_allowance(year): dao_create_or_update_annual_billing_for_year(service_id=row.id, free_sms_fragment_limit=free_allowance[0], financial_year_start=int(year)) + db.session.commit() @notify_command(name='populate-annual-billing-with-defaults') @@ -914,3 +915,4 @@ def populate_annual_billing_with_defaults(year, missing_services_only): for service in active_services: set_default_free_allowance_for_service(service, year) + db.session.commit() diff --git a/app/dao/annual_billing_dao.py b/app/dao/annual_billing_dao.py index 5caa46011..0ce4d0357 100644 --- a/app/dao/annual_billing_dao.py +++ b/app/dao/annual_billing_dao.py @@ -1,12 +1,12 @@ from flask import current_app from app import db -from app.dao.dao_utils import transactional +from app.dao.dao_utils import nested_transactional, transactional from app.dao.date_util import get_current_financial_year_start_year from app.models import AnnualBilling -@transactional +@nested_transactional def dao_create_or_update_annual_billing_for_year(service_id, free_sms_fragment_limit, financial_year_start): result = dao_get_free_sms_fragment_limit_for_year(service_id, financial_year_start) @@ -53,7 +53,7 @@ def dao_get_all_free_sms_fragment_limit(service_id): ).order_by(AnnualBilling.financial_year_start).all() -def set_default_free_allowance_for_service(service, year_start=None, commit=True): +def set_default_free_allowance_for_service(service, year_start=None): default_free_sms_fragment_limits = { 'central': { 2020: 250_000, diff --git a/app/dao/dao_utils.py b/app/dao/dao_utils.py index f5515d8b6..71c881bb3 100644 --- a/app/dao/dao_utils.py +++ b/app/dao/dao_utils.py @@ -18,6 +18,23 @@ def transactional(func): return commit_or_rollback +def nested_transactional(func): + # This creates a save point for the nested transaction. + # You must manage the commit or rollback from outer most call of the nested of the transactions. + @wraps(func) + def commit_or_rollback(*args, **kwargs): + try: + db.session.begin_nested() + res = func(*args, **kwargs) + db.session.commit() + return res + except Exception: + db.session.rollback() + raise + + return commit_or_rollback + + class VersionOptions(): def __init__(self, model_class, history_class=None, must_write_history=True): diff --git a/app/dao/organisation_dao.py b/app/dao/organisation_dao.py index 877f8af30..e4b95d92f 100644 --- a/app/dao/organisation_dao.py +++ b/app/dao/organisation_dao.py @@ -1,7 +1,12 @@ from sqlalchemy.sql.expression import func from app import db -from app.dao.dao_utils import VersionOptions, transactional, version_class +from app.dao.dao_utils import ( + VersionOptions, + nested_transactional, + transactional, + version_class, +) from app.models import Domain, Organisation, Service, User @@ -105,7 +110,7 @@ def _update_organisation_services(organisation, attribute, only_where_none=True) db.session.add(service) -@transactional +@nested_transactional @version_class(Service) def dao_add_service_to_organisation(service, organisation_id): organisation = Organisation.query.filter_by( diff --git a/app/dao/services_dao.py b/app/dao/services_dao.py index c68092c64..18071f5fa 100644 --- a/app/dao/services_dao.py +++ b/app/dao/services_dao.py @@ -7,7 +7,12 @@ from sqlalchemy.orm import joinedload from sqlalchemy.sql.expression import and_, asc, case, func from app import db -from app.dao.dao_utils import VersionOptions, transactional, version_class +from app.dao.dao_utils import ( + VersionOptions, + nested_transactional, + transactional, + version_class, +) from app.dao.date_util import get_current_financial_year from app.dao.email_branding_dao import dao_get_email_branding_by_name from app.dao.letter_branding_dao import dao_get_letter_branding_by_name @@ -284,7 +289,7 @@ def dao_fetch_service_by_id_and_user(service_id, user_id): ).one() -@transactional +@nested_transactional @version_class(Service) def dao_create_service( service, diff --git a/app/organisation/rest.py b/app/organisation/rest.py index d745a7adb..06b4ea83c 100644 --- a/app/organisation/rest.py +++ b/app/organisation/rest.py @@ -2,6 +2,7 @@ from flask import Blueprint, abort, current_app, jsonify, request from sqlalchemy.exc import IntegrityError, SQLAlchemyError +from app import db from app.config import QueueNames from app.dao.annual_billing_dao import set_default_free_allowance_for_service from app.dao.fact_billing_dao import fetch_usage_year_for_organisation @@ -119,17 +120,13 @@ def link_service_to_organisation(organisation_id): service = dao_fetch_service_by_id(data['service_id']) service.organisation = None - dao_add_service_to_organisation(service, organisation_id) - # Need to do the annual billing update in a separate transaction because the both the - # dao_add_service_to_organisation and set_default_free_allowance_for_service are wrapped in a transaction. - # Catch and report an error if the annual billing doesn't happen - but don't rollback the service update. try: + dao_add_service_to_organisation(service, organisation_id) set_default_free_allowance_for_service(service, year_start=None) - except SQLAlchemyError: - # No need to worry about key errors because service.organisation_type has a foreign key to organisation_types - current_app.logger.exception( - f"Exception caught when trying to update annual billing when the organisation " - f"changed for service: {service.id} to organisation: {organisation_id}") + db.session.commit() + except SQLAlchemyError as e: + db.session.rollback() + raise e return '', 204 diff --git a/app/service/rest.py b/app/service/rest.py index 03d111be9..9a6996a22 100644 --- a/app/service/rest.py +++ b/app/service/rest.py @@ -7,6 +7,7 @@ from notifications_utils.timezones import convert_utc_to_bst from sqlalchemy.exc import IntegrityError, SQLAlchemyError from sqlalchemy.orm.exc import NoResultFound +from app import db from app.aws import s3 from app.config import QueueNames from app.dao import fact_notification_status_dao, notifications_dao @@ -254,18 +255,13 @@ def create_service(): # unpack valid json into service object valid_service = Service.from_json(data) - dao_create_service(valid_service, user) - - # Need to do the annual billing update in a separate transaction because the both the - # dao_add_service_to_organisation and set_default_free_allowance_for_service are wrapped in a transaction. - # Catch and report an error if the annual billing doesn't happen - but don't rollback the service update. try: - set_default_free_allowance_for_service(valid_service, year_start=None) - except SQLAlchemyError: - # No need to worry about key errors because service.organisation_type has a foreign key to organisation_types - current_app.logger.exception( - f"Exception caught when trying to insert annual billing creating a service {valid_service.id} " - f"for organisation_type {valid_service.organisation_type}") + dao_create_service(valid_service, user) + set_default_free_allowance_for_service(service=valid_service, year_start=None) + db.session.commit() + except SQLAlchemyError as e: + db.session.rollback() + raise e return jsonify(data=service_schema.dump(valid_service).data), 201 diff --git a/tests/app/dao/test_organisation_dao.py b/tests/app/dao/test_organisation_dao.py index 91136de8a..ed938ea61 100644 --- a/tests/app/dao/test_organisation_dao.py +++ b/tests/app/dao/test_organisation_dao.py @@ -233,7 +233,7 @@ def test_add_service_to_organisation(sample_service, sample_organisation): sample_organisation.crown = False dao_add_service_to_organisation(sample_service, sample_organisation.id) - + db.session.commit() assert len(sample_organisation.services) == 1 assert sample_organisation.services[0].id == sample_service.id diff --git a/tests/app/organisation/test_rest.py b/tests/app/organisation/test_rest.py index 5b2def4e4..2c3e3d0bb 100644 --- a/tests/app/organisation/test_rest.py +++ b/tests/app/organisation/test_rest.py @@ -505,14 +505,54 @@ def test_post_link_service_to_organisation(admin_request, sample_service): organisation_id=organisation.id, _expected_status=204 ) - assert len(organisation.services) == 1 assert sample_service.organisation_type == 'central' + + +def test_post_link_service_to_organisation_inserts_annual_billing(admin_request, sample_service): + data = { + 'service_id': str(sample_service.id) + } + organisation = create_organisation(organisation_type='central') + assert len(organisation.services) == 0 + assert len(AnnualBilling.query.all()) == 0 + admin_request.post( + 'organisation.link_service_to_organisation', + _data=data, + organisation_id=organisation.id, + _expected_status=204 + ) + annual_billing = AnnualBilling.query.all() assert len(annual_billing) == 1 assert annual_billing[0].free_sms_fragment_limit == 150000 +def test_post_link_service_to_organisation_rollback_service_if_annual_billing_update_fails( + admin_request, sample_service, mocker +): + mocker.patch('app.dao.annual_billing_dao.dao_create_or_update_annual_billing_for_year', + side_effect=SQLAlchemyError) + data = { + 'service_id': str(sample_service.id) + } + assert not sample_service.organisation_type + + organisation = create_organisation(organisation_type='central') + assert len(organisation.services) == 0 + assert len(AnnualBilling.query.all()) == 0 + with pytest.raises(expected_exception=SQLAlchemyError): + admin_request.post( + 'organisation.link_service_to_organisation', + _data=data, + organisation_id=organisation.id, + _expected_status=404 + ) + assert not sample_service.organisation_type + assert len(organisation.services) == 0 + assert len(AnnualBilling.query.all()) == 0 + + def test_post_link_service_to_another_org( admin_request, sample_service, sample_organisation): data = { @@ -582,23 +622,6 @@ def test_post_link_service_to_organisation_missing_payload( ) -def test_link_service_to_organisation_updates_service_if_annual_billing_update_fails( - mocker, admin_request, sample_service, sample_organisation -): - mocker.patch('app.organisation.rest.set_default_free_allowance_for_service', raises=SQLAlchemyError) - data = { - 'service_id': str(sample_service.id) - } - admin_request.post( - 'organisation.link_service_to_organisation', - organisation_id=str(sample_organisation.id), - _data=data, - _expected_status=204 - ) - assert sample_service.organisation_id == sample_organisation.id - assert len(AnnualBilling.query.all()) == 0 - - def test_rest_get_organisation_services( admin_request, sample_organisation, sample_service): dao_add_service_to_organisation(sample_service, sample_organisation.id) diff --git a/tests/app/service/test_rest.py b/tests/app/service/test_rest.py index 6e2599a45..7484d6203 100644 --- a/tests/app/service/test_rest.py +++ b/tests/app/service/test_rest.py @@ -503,10 +503,10 @@ def test_create_service_should_create_annual_billing_for_service( assert len(annual_billing) == 1 -def test_create_service_should_create_service_if_annual_billing_query_fails( +def test_create_service_should_raise_exception_and_not_create_service_if_annual_billing_query_fails( admin_request, sample_user, mocker ): - mocker.patch('app.service.rest.set_default_free_allowance_for_service', raises=SQLAlchemyError) + mocker.patch('app.service.rest.set_default_free_allowance_for_service', side_effect=SQLAlchemyError) data = { 'name': 'created service', 'user_id': str(sample_user.id), @@ -517,11 +517,12 @@ def test_create_service_should_create_service_if_annual_billing_query_fails( 'created_by': str(sample_user.id) } assert len(AnnualBilling.query.all()) == 0 - admin_request.post('service.create_service', _data=data, _expected_status=201) + with pytest.raises(expected_exception=SQLAlchemyError): + admin_request.post('service.create_service', _data=data) annual_billing = AnnualBilling.query.all() assert len(annual_billing) == 0 - assert len(Service.query.filter(Service.name == 'created service').all()) == 1 + assert len(Service.query.filter(Service.name == 'created service').all()) == 0 def test_create_service_inherits_branding_from_organisation(