From 93908bacda1a3948e1a03e30038991d434ee31c5 Mon Sep 17 00:00:00 2001 From: Rebecca Law Date: Tue, 13 Apr 2021 15:02:46 +0100 Subject: [PATCH] New strategy for transaction management. Introduce a contextmanger function to handle exceptions and nested transactions. Using the nested_transaction will start a nested transaction with `db.session.begin_nested`, once the nested transaction is complete the commit will happen. `@transactional` has been updated to commit unless in a nested transaction. --- app/billing/rest.py | 3 --- app/commands.py | 2 -- app/dao/annual_billing_dao.py | 4 ++-- app/dao/dao_utils.py | 31 +++++++++++++------------- app/dao/organisation_dao.py | 9 ++------ app/dao/services_dao.py | 9 ++------ app/organisation/rest.py | 10 +++------ app/service/rest.py | 11 +++------ tests/app/dao/test_organisation_dao.py | 2 +- 9 files changed, 29 insertions(+), 52 deletions(-) diff --git a/app/billing/rest.py b/app/billing/rest.py index 95c79798f..3f5f6e08c 100644 --- a/app/billing/rest.py +++ b/app/billing/rest.py @@ -1,6 +1,5 @@ from flask import Blueprint, jsonify, request -from app import db from app.billing.billing_schemas import ( create_or_update_free_sms_fragment_limit_schema, serialize_ft_billing_remove_emails, @@ -84,7 +83,6 @@ def get_free_sms_fragment_limit(service_id): annual_billing = dao_create_or_update_annual_billing_for_year(service_id, annual_billing.free_sms_fragment_limit, financial_year_start) - db.session.commit() return jsonify(annual_billing.serialize_free_sms_items()), 200 @@ -120,4 +118,3 @@ def update_free_sms_fragment_limit_data(service_id, free_sms_fragment_limit, fin free_sms_fragment_limit, financial_year_start ) - db.session.commit() diff --git a/app/commands.py b/app/commands.py index 12801b4a8..c98ce9bf3 100644 --- a/app/commands.py +++ b/app/commands.py @@ -880,7 +880,6 @@ def populate_annual_billing_with_the_previous_years_allowance(year): dao_create_or_update_annual_billing_for_year(service_id=row.id, free_sms_fragment_limit=free_allowance[0], financial_year_start=int(year)) - db.session.commit() @notify_command(name='populate-annual-billing-with-defaults') @@ -915,4 +914,3 @@ def populate_annual_billing_with_defaults(year, missing_services_only): for service in active_services: set_default_free_allowance_for_service(service, year) - db.session.commit() diff --git a/app/dao/annual_billing_dao.py b/app/dao/annual_billing_dao.py index 0ce4d0357..e1309b0c4 100644 --- a/app/dao/annual_billing_dao.py +++ b/app/dao/annual_billing_dao.py @@ -1,12 +1,12 @@ from flask import current_app from app import db -from app.dao.dao_utils import nested_transactional, transactional +from app.dao.dao_utils import transactional from app.dao.date_util import get_current_financial_year_start_year from app.models import AnnualBilling -@nested_transactional +@transactional def dao_create_or_update_annual_billing_for_year(service_id, free_sms_fragment_limit, financial_year_start): result = dao_get_free_sms_fragment_limit_for_year(service_id, financial_year_start) diff --git a/app/dao/dao_utils.py b/app/dao/dao_utils.py index 71c881bb3..0e9470b03 100644 --- a/app/dao/dao_utils.py +++ b/app/dao/dao_utils.py @@ -1,4 +1,5 @@ import itertools +from contextlib import contextmanager from functools import wraps from app import db @@ -10,7 +11,10 @@ def transactional(func): def commit_or_rollback(*args, **kwargs): try: res = func(*args, **kwargs) - db.session.commit() + + if not db.session.registry().transaction.nested: + db.session.commit() + return res except Exception: db.session.rollback() @@ -18,21 +22,18 @@ def transactional(func): return commit_or_rollback -def nested_transactional(func): - # This creates a save point for the nested transaction. - # You must manage the commit or rollback from outer most call of the nested of the transactions. - @wraps(func) - def commit_or_rollback(*args, **kwargs): - try: - db.session.begin_nested() - res = func(*args, **kwargs) - db.session.commit() - return res - except Exception: - db.session.rollback() - raise +@contextmanager +def nested_transaction(): + try: + db.session.begin_nested() + yield + db.session.commit() - return commit_or_rollback + if not db.session.registry().transaction.nested: + db.session.commit() + except Exception: + db.session.rollback() + raise class VersionOptions(): diff --git a/app/dao/organisation_dao.py b/app/dao/organisation_dao.py index e4b95d92f..877f8af30 100644 --- a/app/dao/organisation_dao.py +++ b/app/dao/organisation_dao.py @@ -1,12 +1,7 @@ from sqlalchemy.sql.expression import func from app import db -from app.dao.dao_utils import ( - VersionOptions, - nested_transactional, - transactional, - version_class, -) +from app.dao.dao_utils import VersionOptions, transactional, version_class from app.models import Domain, Organisation, Service, User @@ -110,7 +105,7 @@ def _update_organisation_services(organisation, attribute, only_where_none=True) db.session.add(service) -@nested_transactional +@transactional @version_class(Service) def dao_add_service_to_organisation(service, organisation_id): organisation = Organisation.query.filter_by( diff --git a/app/dao/services_dao.py b/app/dao/services_dao.py index 18071f5fa..c68092c64 100644 --- a/app/dao/services_dao.py +++ b/app/dao/services_dao.py @@ -7,12 +7,7 @@ from sqlalchemy.orm import joinedload from sqlalchemy.sql.expression import and_, asc, case, func from app import db -from app.dao.dao_utils import ( - VersionOptions, - nested_transactional, - transactional, - version_class, -) +from app.dao.dao_utils import VersionOptions, transactional, version_class from app.dao.date_util import get_current_financial_year from app.dao.email_branding_dao import dao_get_email_branding_by_name from app.dao.letter_branding_dao import dao_get_letter_branding_by_name @@ -289,7 +284,7 @@ def dao_fetch_service_by_id_and_user(service_id, user_id): ).one() -@nested_transactional +@transactional @version_class(Service) def dao_create_service( service, diff --git a/app/organisation/rest.py b/app/organisation/rest.py index 06b4ea83c..ec8d6aa25 100644 --- a/app/organisation/rest.py +++ b/app/organisation/rest.py @@ -1,10 +1,10 @@ from flask import Blueprint, abort, current_app, jsonify, request -from sqlalchemy.exc import IntegrityError, SQLAlchemyError +from sqlalchemy.exc import IntegrityError -from app import db from app.config import QueueNames from app.dao.annual_billing_dao import set_default_free_allowance_for_service +from app.dao.dao_utils import nested_transaction from app.dao.fact_billing_dao import fetch_usage_year_for_organisation from app.dao.organisation_dao import ( dao_add_service_to_organisation, @@ -120,13 +120,9 @@ def link_service_to_organisation(organisation_id): service = dao_fetch_service_by_id(data['service_id']) service.organisation = None - try: + with nested_transaction(): dao_add_service_to_organisation(service, organisation_id) set_default_free_allowance_for_service(service, year_start=None) - db.session.commit() - except SQLAlchemyError as e: - db.session.rollback() - raise e return '', 204 diff --git a/app/service/rest.py b/app/service/rest.py index 9a6996a22..226ab3cc4 100644 --- a/app/service/rest.py +++ b/app/service/rest.py @@ -4,10 +4,9 @@ from datetime import datetime from flask import Blueprint, current_app, jsonify, request from notifications_utils.letter_timings import letter_can_be_cancelled from notifications_utils.timezones import convert_utc_to_bst -from sqlalchemy.exc import IntegrityError, SQLAlchemyError +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm.exc import NoResultFound -from app import db from app.aws import s3 from app.config import QueueNames from app.dao import fact_notification_status_dao, notifications_dao @@ -19,7 +18,7 @@ from app.dao.api_key_dao import ( save_model_api_key, ) from app.dao.broadcast_service_dao import set_broadcast_service_type -from app.dao.dao_utils import dao_rollback +from app.dao.dao_utils import dao_rollback, nested_transaction from app.dao.date_util import get_financial_year from app.dao.fact_notification_status_dao import ( fetch_monthly_template_usage_for_service, @@ -255,13 +254,9 @@ def create_service(): # unpack valid json into service object valid_service = Service.from_json(data) - try: + with nested_transaction(): dao_create_service(valid_service, user) set_default_free_allowance_for_service(service=valid_service, year_start=None) - db.session.commit() - except SQLAlchemyError as e: - db.session.rollback() - raise e return jsonify(data=service_schema.dump(valid_service).data), 201 diff --git a/tests/app/dao/test_organisation_dao.py b/tests/app/dao/test_organisation_dao.py index ed938ea61..91136de8a 100644 --- a/tests/app/dao/test_organisation_dao.py +++ b/tests/app/dao/test_organisation_dao.py @@ -233,7 +233,7 @@ def test_add_service_to_organisation(sample_service, sample_organisation): sample_organisation.crown = False dao_add_service_to_organisation(sample_service, sample_organisation.id) - db.session.commit() + assert len(sample_organisation.services) == 1 assert sample_organisation.services[0].id == sample_service.id