diff --git a/app/billing/rest.py b/app/billing/rest.py index 3f5f6e08c..95c79798f 100644 --- a/app/billing/rest.py +++ b/app/billing/rest.py @@ -1,5 +1,6 @@ from flask import Blueprint, jsonify, request +from app import db from app.billing.billing_schemas import ( create_or_update_free_sms_fragment_limit_schema, serialize_ft_billing_remove_emails, @@ -83,6 +84,7 @@ def get_free_sms_fragment_limit(service_id): annual_billing = dao_create_or_update_annual_billing_for_year(service_id, annual_billing.free_sms_fragment_limit, financial_year_start) + db.session.commit() return jsonify(annual_billing.serialize_free_sms_items()), 200 @@ -118,3 +120,4 @@ def update_free_sms_fragment_limit_data(service_id, free_sms_fragment_limit, fin free_sms_fragment_limit, financial_year_start ) + db.session.commit() diff --git a/app/commands.py b/app/commands.py index c98ce9bf3..12801b4a8 100644 --- a/app/commands.py +++ b/app/commands.py @@ -880,6 +880,7 @@ def populate_annual_billing_with_the_previous_years_allowance(year): dao_create_or_update_annual_billing_for_year(service_id=row.id, free_sms_fragment_limit=free_allowance[0], financial_year_start=int(year)) + db.session.commit() @notify_command(name='populate-annual-billing-with-defaults') @@ -914,3 +915,4 @@ def populate_annual_billing_with_defaults(year, missing_services_only): for service in active_services: set_default_free_allowance_for_service(service, year) + db.session.commit() diff --git a/app/dao/annual_billing_dao.py b/app/dao/annual_billing_dao.py index 5caa46011..0ce4d0357 100644 --- a/app/dao/annual_billing_dao.py +++ b/app/dao/annual_billing_dao.py @@ -1,12 +1,12 @@ from flask import current_app from app import db -from app.dao.dao_utils import transactional +from app.dao.dao_utils import nested_transactional, transactional from app.dao.date_util import get_current_financial_year_start_year from app.models import AnnualBilling -@transactional +@nested_transactional def dao_create_or_update_annual_billing_for_year(service_id, free_sms_fragment_limit, financial_year_start): result = dao_get_free_sms_fragment_limit_for_year(service_id, financial_year_start) @@ -53,7 +53,7 @@ def dao_get_all_free_sms_fragment_limit(service_id): ).order_by(AnnualBilling.financial_year_start).all() -def set_default_free_allowance_for_service(service, year_start=None, commit=True): +def set_default_free_allowance_for_service(service, year_start=None): default_free_sms_fragment_limits = { 'central': { 2020: 250_000, diff --git a/app/dao/dao_utils.py b/app/dao/dao_utils.py index f5515d8b6..71c881bb3 100644 --- a/app/dao/dao_utils.py +++ b/app/dao/dao_utils.py @@ -18,6 +18,23 @@ def transactional(func): return commit_or_rollback +def nested_transactional(func): + # This creates a save point for the nested transaction. + # You must manage the commit or rollback from outer most call of the nested of the transactions. + @wraps(func) + def commit_or_rollback(*args, **kwargs): + try: + db.session.begin_nested() + res = func(*args, **kwargs) + db.session.commit() + return res + except Exception: + db.session.rollback() + raise + + return commit_or_rollback + + class VersionOptions(): def __init__(self, model_class, history_class=None, must_write_history=True): diff --git a/app/dao/organisation_dao.py b/app/dao/organisation_dao.py index 877f8af30..e4b95d92f 100644 --- a/app/dao/organisation_dao.py +++ b/app/dao/organisation_dao.py @@ -1,7 +1,12 @@ from sqlalchemy.sql.expression import func from app import db -from app.dao.dao_utils import VersionOptions, transactional, version_class +from app.dao.dao_utils import ( + VersionOptions, + nested_transactional, + transactional, + version_class, +) from app.models import Domain, Organisation, Service, User @@ -105,7 +110,7 @@ def _update_organisation_services(organisation, attribute, only_where_none=True) db.session.add(service) -@transactional +@nested_transactional @version_class(Service) def dao_add_service_to_organisation(service, organisation_id): organisation = Organisation.query.filter_by( diff --git a/app/dao/services_dao.py b/app/dao/services_dao.py index c68092c64..18071f5fa 100644 --- a/app/dao/services_dao.py +++ b/app/dao/services_dao.py @@ -7,7 +7,12 @@ from sqlalchemy.orm import joinedload from sqlalchemy.sql.expression import and_, asc, case, func from app import db -from app.dao.dao_utils import VersionOptions, transactional, version_class +from app.dao.dao_utils import ( + VersionOptions, + nested_transactional, + transactional, + version_class, +) from app.dao.date_util import get_current_financial_year from app.dao.email_branding_dao import dao_get_email_branding_by_name from app.dao.letter_branding_dao import dao_get_letter_branding_by_name @@ -284,7 +289,7 @@ def dao_fetch_service_by_id_and_user(service_id, user_id): ).one() -@transactional +@nested_transactional @version_class(Service) def dao_create_service( service, diff --git a/app/organisation/rest.py b/app/organisation/rest.py index d745a7adb..06b4ea83c 100644 --- a/app/organisation/rest.py +++ b/app/organisation/rest.py @@ -2,6 +2,7 @@ from flask import Blueprint, abort, current_app, jsonify, request from sqlalchemy.exc import IntegrityError, SQLAlchemyError +from app import db from app.config import QueueNames from app.dao.annual_billing_dao import set_default_free_allowance_for_service from app.dao.fact_billing_dao import fetch_usage_year_for_organisation @@ -119,17 +120,13 @@ def link_service_to_organisation(organisation_id): service = dao_fetch_service_by_id(data['service_id']) service.organisation = None - dao_add_service_to_organisation(service, organisation_id) - # Need to do the annual billing update in a separate transaction because the both the - # dao_add_service_to_organisation and set_default_free_allowance_for_service are wrapped in a transaction. - # Catch and report an error if the annual billing doesn't happen - but don't rollback the service update. try: + dao_add_service_to_organisation(service, organisation_id) set_default_free_allowance_for_service(service, year_start=None) - except SQLAlchemyError: - # No need to worry about key errors because service.organisation_type has a foreign key to organisation_types - current_app.logger.exception( - f"Exception caught when trying to update annual billing when the organisation " - f"changed for service: {service.id} to organisation: {organisation_id}") + db.session.commit() + except SQLAlchemyError as e: + db.session.rollback() + raise e return '', 204 diff --git a/app/service/rest.py b/app/service/rest.py index 03d111be9..9a6996a22 100644 --- a/app/service/rest.py +++ b/app/service/rest.py @@ -7,6 +7,7 @@ from notifications_utils.timezones import convert_utc_to_bst from sqlalchemy.exc import IntegrityError, SQLAlchemyError from sqlalchemy.orm.exc import NoResultFound +from app import db from app.aws import s3 from app.config import QueueNames from app.dao import fact_notification_status_dao, notifications_dao @@ -254,18 +255,13 @@ def create_service(): # unpack valid json into service object valid_service = Service.from_json(data) - dao_create_service(valid_service, user) - - # Need to do the annual billing update in a separate transaction because the both the - # dao_add_service_to_organisation and set_default_free_allowance_for_service are wrapped in a transaction. - # Catch and report an error if the annual billing doesn't happen - but don't rollback the service update. try: - set_default_free_allowance_for_service(valid_service, year_start=None) - except SQLAlchemyError: - # No need to worry about key errors because service.organisation_type has a foreign key to organisation_types - current_app.logger.exception( - f"Exception caught when trying to insert annual billing creating a service {valid_service.id} " - f"for organisation_type {valid_service.organisation_type}") + dao_create_service(valid_service, user) + set_default_free_allowance_for_service(service=valid_service, year_start=None) + db.session.commit() + except SQLAlchemyError as e: + db.session.rollback() + raise e return jsonify(data=service_schema.dump(valid_service).data), 201 diff --git a/tests/app/dao/test_organisation_dao.py b/tests/app/dao/test_organisation_dao.py index 91136de8a..ed938ea61 100644 --- a/tests/app/dao/test_organisation_dao.py +++ b/tests/app/dao/test_organisation_dao.py @@ -233,7 +233,7 @@ def test_add_service_to_organisation(sample_service, sample_organisation): sample_organisation.crown = False dao_add_service_to_organisation(sample_service, sample_organisation.id) - + db.session.commit() assert len(sample_organisation.services) == 1 assert sample_organisation.services[0].id == sample_service.id diff --git a/tests/app/organisation/test_rest.py b/tests/app/organisation/test_rest.py index 5b2def4e4..2c3e3d0bb 100644 --- a/tests/app/organisation/test_rest.py +++ b/tests/app/organisation/test_rest.py @@ -505,14 +505,54 @@ def test_post_link_service_to_organisation(admin_request, sample_service): organisation_id=organisation.id, _expected_status=204 ) - assert len(organisation.services) == 1 assert sample_service.organisation_type == 'central' + + +def test_post_link_service_to_organisation_inserts_annual_billing(admin_request, sample_service): + data = { + 'service_id': str(sample_service.id) + } + organisation = create_organisation(organisation_type='central') + assert len(organisation.services) == 0 + assert len(AnnualBilling.query.all()) == 0 + admin_request.post( + 'organisation.link_service_to_organisation', + _data=data, + organisation_id=organisation.id, + _expected_status=204 + ) + annual_billing = AnnualBilling.query.all() assert len(annual_billing) == 1 assert annual_billing[0].free_sms_fragment_limit == 150000 +def test_post_link_service_to_organisation_rollback_service_if_annual_billing_update_fails( + admin_request, sample_service, mocker +): + mocker.patch('app.dao.annual_billing_dao.dao_create_or_update_annual_billing_for_year', + side_effect=SQLAlchemyError) + data = { + 'service_id': str(sample_service.id) + } + assert not sample_service.organisation_type + + organisation = create_organisation(organisation_type='central') + assert len(organisation.services) == 0 + assert len(AnnualBilling.query.all()) == 0 + with pytest.raises(expected_exception=SQLAlchemyError): + admin_request.post( + 'organisation.link_service_to_organisation', + _data=data, + organisation_id=organisation.id, + _expected_status=404 + ) + assert not sample_service.organisation_type + assert len(organisation.services) == 0 + assert len(AnnualBilling.query.all()) == 0 + + def test_post_link_service_to_another_org( admin_request, sample_service, sample_organisation): data = { @@ -582,23 +622,6 @@ def test_post_link_service_to_organisation_missing_payload( ) -def test_link_service_to_organisation_updates_service_if_annual_billing_update_fails( - mocker, admin_request, sample_service, sample_organisation -): - mocker.patch('app.organisation.rest.set_default_free_allowance_for_service', raises=SQLAlchemyError) - data = { - 'service_id': str(sample_service.id) - } - admin_request.post( - 'organisation.link_service_to_organisation', - organisation_id=str(sample_organisation.id), - _data=data, - _expected_status=204 - ) - assert sample_service.organisation_id == sample_organisation.id - assert len(AnnualBilling.query.all()) == 0 - - def test_rest_get_organisation_services( admin_request, sample_organisation, sample_service): dao_add_service_to_organisation(sample_service, sample_organisation.id) diff --git a/tests/app/service/test_rest.py b/tests/app/service/test_rest.py index 6e2599a45..7484d6203 100644 --- a/tests/app/service/test_rest.py +++ b/tests/app/service/test_rest.py @@ -503,10 +503,10 @@ def test_create_service_should_create_annual_billing_for_service( assert len(annual_billing) == 1 -def test_create_service_should_create_service_if_annual_billing_query_fails( +def test_create_service_should_raise_exception_and_not_create_service_if_annual_billing_query_fails( admin_request, sample_user, mocker ): - mocker.patch('app.service.rest.set_default_free_allowance_for_service', raises=SQLAlchemyError) + mocker.patch('app.service.rest.set_default_free_allowance_for_service', side_effect=SQLAlchemyError) data = { 'name': 'created service', 'user_id': str(sample_user.id), @@ -517,11 +517,12 @@ def test_create_service_should_create_service_if_annual_billing_query_fails( 'created_by': str(sample_user.id) } assert len(AnnualBilling.query.all()) == 0 - admin_request.post('service.create_service', _data=data, _expected_status=201) + with pytest.raises(expected_exception=SQLAlchemyError): + admin_request.post('service.create_service', _data=data) annual_billing = AnnualBilling.query.all() assert len(annual_billing) == 0 - assert len(Service.query.filter(Service.name == 'created service').all()) == 1 + assert len(Service.query.filter(Service.name == 'created service').all()) == 0 def test_create_service_inherits_branding_from_organisation(